[llvm] r261544 - [attrs] Handle convergent CallSites.

Justin Lebar via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 22 09:51:37 PST 2016


Author: jlebar
Date: Mon Feb 22 11:51:35 2016
New Revision: 261544

URL: http://llvm.org/viewvc/llvm-project?rev=261544&view=rev
Log:
[attrs] Handle convergent CallSites.

Summary:
Previously we had a notion of convergent functions but not of convergent
calls.  This is insufficient to correctly analyze calls where the target
is unknown, e.g. indirect calls.

Now a call is convergent if it targets a known-convergent function, or
if it's explicitly marked as convergent.  As usual, we can remove
convergent where we can prove that no convergent operations are
performed in the call.

Reviewers: chandlerc, jingyue

Subscribers: hfinkel, jhen, tra, llvm-commits

Differential Revision: http://reviews.llvm.org/D17317

Added:
    llvm/trunk/test/Transforms/InstCombine/convergent.ll
Modified:
    llvm/trunk/lib/Transforms/IPO/FunctionAttrs.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/trunk/test/Transforms/FunctionAttrs/convergent.ll

Modified: llvm/trunk/lib/Transforms/IPO/FunctionAttrs.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/IPO/FunctionAttrs.cpp?rev=261544&r1=261543&r2=261544&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/IPO/FunctionAttrs.cpp (original)
+++ llvm/trunk/lib/Transforms/IPO/FunctionAttrs.cpp Mon Feb 22 11:51:35 2016
@@ -903,49 +903,37 @@ static bool addNonNullAttrs(const SCCNod
   return MadeChange;
 }
 
-/// Removes convergent attributes where we can prove that none of the SCC's
-/// callees are themselves convergent.  Returns true if successful at removing
-/// the attribute.
+/// Remove the convergent attribute from all functions in the SCC if every
+/// callsite within the SCC is not convergent (except for calls to functions
+/// within the SCC).  Returns true if changes were made.
 static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) {
-  // Determines whether a function can be made non-convergent, ignoring all
-  // other functions in SCC.  (A function can *actually* be made non-convergent
-  // only if all functions in its SCC can be made convergent.)
-  auto CanRemoveConvergent = [&](Function *F) {
-    if (!F->isConvergent())
-      return true;
-
-    // Can't remove convergent from declarations.
-    if (F->isDeclaration())
-      return false;
-
-    for (Instruction &I : instructions(*F))
-      if (auto CS = CallSite(&I)) {
-        // Can't remove convergent if any of F's callees -- ignoring functions
-        // in the SCC itself -- are convergent. This needs to consider both
-        // function calls and intrinsic calls. We also assume indirect calls
-        // might call a convergent function.
-        // FIXME: We should revisit this when we put convergent onto calls
-        // instead of functions so that indirect calls which should be
-        // convergent are required to be marked as such.
-        Function *Callee = CS.getCalledFunction();
-        if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent()))
-          return false;
-      }
-
-    return true;
-  };
+  // No point checking if none of SCCNodes is convergent.
+  if (!llvm::any_of(SCCNodes, [](Function *F) { return F->isConvergent(); }))
+    return false;
 
-  // We can remove the convergent attr from functions in the SCC if they all
-  // can be made non-convergent (because they call only non-convergent
-  // functions, other than each other).
-  if (!llvm::all_of(SCCNodes, CanRemoveConvergent))
+  // Can't remove convergent from function declarations.
+  if (llvm::any_of(SCCNodes, [](Function *F) { return F->isDeclaration(); }))
     return false;
 
-  // If we got here, all of the SCC's callees are non-convergent. Therefore all
-  // of the SCC's functions can be marked as non-convergent.
+  // Can't remove convergent if any of our functions has a convergent call to a
+  // function not in the SCC.
+  for (Function *F : SCCNodes)
+    for (Instruction &I : instructions(*F)) {
+      CallSite CS(&I);
+      // Bail if CS is a convergent call to a function not in the SCC.
+      if (CS && CS.isConvergent() &&
+          SCCNodes.count(CS.getCalledFunction()) == 0)
+        return false;
+    }
+
+  // If we got here, all of the calls the SCC makes to functions not in the SCC
+  // are non-convergent. Therefore all of the SCC's functions can also be made
+  // non-convergent.  We'll remove the attr from the callsites in
+  // InstCombineCalls.
   for (Function *F : SCCNodes) {
     if (F->isConvergent())
-      DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n");
+      DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName()
+                   << "\n");
     F->setNotConvergent();
   }
   return true;

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp?rev=261544&r1=261543&r2=261544&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp Mon Feb 22 11:51:35 2016
@@ -2070,7 +2070,15 @@ Instruction *InstCombiner::visitCallSite
   if (!isa<Function>(Callee) && transformConstExprCastCall(CS))
     return nullptr;
 
-  if (Function *CalleeF = dyn_cast<Function>(Callee))
+  if (Function *CalleeF = dyn_cast<Function>(Callee)) {
+    // Remove the convergent attr on calls when the callee is not convergent.
+    if (CS.isConvergent() && !CalleeF->isConvergent()) {
+      DEBUG(dbgs() << "Removing convergent attr from instr "
+                   << CS.getInstruction() << "\n");
+      CS.setNotConvergent();
+      return CS.getInstruction();
+    }
+
     // If the call and callee calling conventions don't match, this call must
     // be unreachable, as the call is undefined.
     if (CalleeF->getCallingConv() != CS.getCallingConv() &&
@@ -2095,6 +2103,7 @@ Instruction *InstCombiner::visitCallSite
                                     Constant::getNullValue(CalleeF->getType()));
       return nullptr;
     }
+  }
 
   if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) {
     // If CS does not return void then replaceAllUsesWith undef.

Modified: llvm/trunk/test/Transforms/FunctionAttrs/convergent.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/FunctionAttrs/convergent.ll?rev=261544&r1=261543&r2=261544&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/FunctionAttrs/convergent.ll (original)
+++ llvm/trunk/test/Transforms/FunctionAttrs/convergent.ll Mon Feb 22 11:51:35 2016
@@ -1,4 +1,4 @@
-; RUN: opt < %s -basicaa -functionattrs -rpo-functionattrs -S | FileCheck %s
+; RUN: opt -functionattrs -S < %s | FileCheck %s
 
 ; CHECK: Function Attrs
 ; CHECK-NOT: convergent
@@ -24,20 +24,41 @@ declare i32 @k() convergent
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @extern()
 define i32 @extern() convergent {
-  %a = call i32 @k()
+  %a = call i32 @k() convergent
   ret i32 %a
 }
 
+; Convergent should not be removed on the function here.  Although the call is
+; not explicitly convergent, it picks up the convergent attr from the callee.
+;
 ; CHECK: Function Attrs
 ; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @call_extern()
-define i32 @call_extern() convergent {
-  %a = call i32 @extern()
+; CHECK-NEXT: define i32 @extern_non_convergent_call()
+define i32 @extern_non_convergent_call() convergent {
+  %a = call i32 @k()
   ret i32 %a
 }
 
 ; CHECK: Function Attrs
 ; CHECK-SAME: convergent
+; CHECK-NEXT: define i32 @indirect_convergent_call(
+define i32 @indirect_convergent_call(i32 ()* %f) convergent {
+   %a = call i32 %f() convergent
+   ret i32 %a
+}
+; Give indirect_non_convergent_call the norecurse attribute so we get a
+; "Function Attrs" comment in the output.
+;
+; CHECK: Function Attrs
+; CHECK-NOT: convergent
+; CHECK-NEXT: define i32 @indirect_non_convergent_call(
+define i32 @indirect_non_convergent_call(i32 ()* %f) convergent norecurse {
+   %a = call i32 %f()
+   ret i32 %a
+}
+
+; CHECK: Function Attrs
+; CHECK-SAME: convergent
 ; CHECK-NEXT: declare void @llvm.cuda.syncthreads()
 declare void @llvm.cuda.syncthreads() convergent
 
@@ -45,25 +66,16 @@ declare void @llvm.cuda.syncthreads() co
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @intrinsic()
 define i32 @intrinsic() convergent {
+  ; Implicitly convergent, because the intrinsic is convergent.
   call void @llvm.cuda.syncthreads()
   ret i32 0
 }
 
-@xyz = global i32 ()* null
-; CHECK: Function Attrs
-; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @functionptr()
-define i32 @functionptr() convergent {
-  %1 = load i32 ()*, i32 ()** @xyz
-  %2 = call i32 %1()
-  ret i32 %2
-}
-
 ; CHECK: Function Attrs
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive1()
 define i32 @recursive1() convergent {
-  %a = call i32 @recursive2()
+  %a = call i32 @recursive2() convergent
   ret i32 %a
 }
 
@@ -71,7 +83,7 @@ define i32 @recursive1() convergent {
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive2()
 define i32 @recursive2() convergent {
-  %a = call i32 @recursive1()
+  %a = call i32 @recursive1() convergent
   ret i32 %a
 }
 
@@ -79,7 +91,7 @@ define i32 @recursive2() convergent {
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @noopt()
 define i32 @noopt() convergent optnone noinline {
-  %a = call i32 @noopt_friend()
+  %a = call i32 @noopt_friend() convergent
   ret i32 0
 }
 

Added: llvm/trunk/test/Transforms/InstCombine/convergent.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/convergent.ll?rev=261544&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/convergent.ll (added)
+++ llvm/trunk/test/Transforms/InstCombine/convergent.ll Mon Feb 22 11:51:35 2016
@@ -0,0 +1,33 @@
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+declare i32 @k() convergent
+declare i32 @f()
+
+define i32 @extern() {
+  ; Convergent attr shouldn't be removed here; k is convergent.
+  ; CHECK: call i32 @k() [[CONVERGENT_ATTR:#[0-9]+]]
+  %a = call i32 @k() convergent
+  ret i32 %a
+}
+
+define i32 @extern_no_attr() {
+  ; Convergent attr shouldn't be added here, even though k is convergent.
+  ; CHECK: call i32 @k(){{$}}
+  %a = call i32 @k()
+  ret i32 %a
+}
+
+define i32 @no_extern() {
+  ; Convergent should be removed here, as the target is not convergent.
+  ; CHECK: call i32 @f(){{$}}
+  %a = call i32 @f() convergent
+  ret i32 %a
+}
+
+define i32 @indirect_call(i32 ()* %f) {
+  ; CHECK: call i32 %f() [[CONVERGENT_ATTR]]
+  %a = call i32 %f() convergent
+  ret i32 %a
+}
+
+; CHECK: [[CONVERGENT_ATTR]] = { convergent }




More information about the llvm-commits mailing list