[llvm] 8dce4c5 - [Inliner] Handle convergence control when inlining a call

Sameer Sahasrabuddhe via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 16 21:26:45 PDT 2023


Author: Sameer Sahasrabuddhe
Date: 2023-08-17T09:56:25+05:30
New Revision: 8dce4c56dd760b34b222d3188a0fe26232928406

URL: https://github.com/llvm/llvm-project/commit/8dce4c56dd760b34b222d3188a0fe26232928406
DIFF: https://github.com/llvm/llvm-project/commit/8dce4c56dd760b34b222d3188a0fe26232928406.diff

LOG: [Inliner] Handle convergence control when inlining a call

When a convergencectrl token is passed to a convergent call, and the called
function in turn calls the entry intrinsic, the intrinsic is now replaced
with the convergencectrl token.

The spec requires the following check:
  A call from function F to function G can be inlined only if:
  - at least one of F or G does not make any convergent calls, or,
  - both F and G make the same kind of convergent calls: controlled or
    uncontrolled.

But this change does not implement this complete check. A proper implementation
requires a whole new analysis that identifies convergence in every function. For
now, we skip that and just do a cursory check for the entry intrinsic. The
underlying assumption is that in a compiler flow that fully implements
convergence control tokens, there is no mixing of controlled and uncontrolled
convergent operations in the whole program.

This is a reboot of the original change D85606 by
Nicolai Haehnle <nicolai.haehnle at amd.com>.

Reviewed By: arsenm, nhaehnle

Differential Revision: https://reviews.llvm.org/D152431

Added: 
    llvm/test/Transforms/Inline/convergence-inline.ll

Modified: 
    llvm/lib/Transforms/Utils/InlineFunction.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index b70236c47415d2..c6d342e26d57e8 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1951,9 +1951,11 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
 
   // The inliner does not know how to inline through calls with operand bundles
   // in general ...
+  Value *ConvergenceControlToken = nullptr;
   if (CB.hasOperandBundles()) {
     for (int i = 0, e = CB.getNumOperandBundles(); i != e; ++i) {
-      uint32_t Tag = CB.getOperandBundleAt(i).getTagID();
+      auto OBUse = CB.getOperandBundleAt(i);
+      uint32_t Tag = OBUse.getTagID();
       // ... but it knows how to inline through "deopt" operand bundles ...
       if (Tag == LLVMContext::OB_deopt)
         continue;
@@ -1964,11 +1966,37 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
         continue;
       if (Tag == LLVMContext::OB_kcfi)
         continue;
+      if (Tag == LLVMContext::OB_convergencectrl) {
+        ConvergenceControlToken = OBUse.Inputs[0].get();
+        continue;
+      }
 
       return InlineResult::failure("unsupported operand bundle");
     }
   }
 
+  // FIXME: The check below is redundant and incomplete. According to spec, if a
+  // convergent call is missing a token, then the caller is using uncontrolled
+  // convergence. If the callee has an entry intrinsic, then the callee is using
+  // controlled convergence, and the call cannot be inlined. A proper
+  // implemenation of this check requires a whole new analysis that identifies
+  // convergence in every function. For now, we skip that and just do this one
+  // cursory check. The underlying assumption is that in a compiler flow that
+  // fully implements convergence control tokens, there is no mixing of
+  // controlled and uncontrolled convergent operations in the whole program.
+  if (CB.isConvergent()) {
+    auto *I = CalledFunc->getEntryBlock().getFirstNonPHI();
+    if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
+      if (IntrinsicCall->getIntrinsicID() ==
+          Intrinsic::experimental_convergence_entry) {
+        if (!ConvergenceControlToken) {
+          return InlineResult::failure(
+              "convergent call needs convergencectrl operand");
+        }
+      }
+    }
+  }
+
   // If the call to the callee cannot throw, set the 'nounwind' flag on any
   // calls that we inline.
   bool MarkNoUnwind = CB.doesNotThrow();
@@ -2258,6 +2286,17 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
             IFI.GetAssumptionCache(*Caller).registerAssumption(II);
   }
 
+  if (ConvergenceControlToken) {
+    auto *I = FirstNewBlock->getFirstNonPHI();
+    if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
+      if (IntrinsicCall->getIntrinsicID() ==
+          Intrinsic::experimental_convergence_entry) {
+        IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken);
+        IntrinsicCall->eraseFromParent();
+      }
+    }
+  }
+
   // If there are any alloca instructions in the block that used to be the entry
   // block for the callee, move them to the entry block of the caller.  First
   // calculate which instruction they should be inserted before.  We insert the

diff  --git a/llvm/test/Transforms/Inline/convergence-inline.ll b/llvm/test/Transforms/Inline/convergence-inline.ll
new file mode 100644
index 00000000000000..8c67e6a59b7db1
--- /dev/null
+++ b/llvm/test/Transforms/Inline/convergence-inline.ll
@@ -0,0 +1,193 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='cgscc(inline)' -S %s | FileCheck %s
+
+define void @nonconvergent_callee() alwaysinline {
+; CHECK-LABEL: @nonconvergent_callee(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.anchor()
+  call void @f(i32 0) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @convergent_callee(i32 %v) convergent alwaysinline {
+; CHECK-LABEL: @convergent_callee(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @f(i32 [[V:%.*]]) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @f(i32 %v) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_nonconvergent() {
+; CHECK-LABEL: @test_nonconvergent(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN_I:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN_I]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @nonconvergent_callee()
+  ret void
+}
+
+define void @test_convergent_basic(i1 %cond) {
+; CHECK-LABEL: @test_convergent_basic(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.anchor()
+  br i1 %cond, label %then, label %end
+
+then:
+  call void @convergent_callee(i32 0) [ "convergencectrl"(token %token) ]
+  br label %end
+
+end:
+  ret void
+}
+
+define void @test_convergent_no_token(i1 %cond) convergent {
+; CHECK-LABEL: @test_convergent_no_token(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @convergent_callee(i32 0)
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @convergent_callee(i32 0)
+  ret void
+}
+
+define void @test_convergent_multiple() convergent {
+; CHECK-LABEL: @test_convergent_multiple(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 1) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 2) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @convergent_callee(i32 0) [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 1) [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 2) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_convergent_loop(i1 %cond) {
+; CHECK-LABEL: @test_convergent_loop(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[HDR:%.*]], label [[END:%.*]]
+; CHECK:       hdr:
+; CHECK-NEXT:    [[TOK_LOOP:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOK_LOOP]]) ]
+; CHECK-NEXT:    br i1 [[COND]], label [[HDR]], label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.anchor()
+  br i1 %cond, label %hdr, label %end
+
+hdr:
+  %tok.loop = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 0) [ "convergencectrl"(token %tok.loop) ]
+  br i1 %cond, label %hdr, label %end
+
+end:
+  ret void
+}
+
+define void @make_indirect_call(ptr %f, i32 %x) convergent alwaysinline {
+; CHECK-LABEL: @make_indirect_call(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void [[F:%.*]](i32 [[X:%.*]]) #[[ATTR2:[0-9]+]] [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  call void %f(i32 %x) convergent [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_indirect_call() convergent {
+; CHECK-LABEL: @test_indirect_call(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @make_indirect_call(ptr @convergent_callee, i32 0) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @recurse() convergent alwaysinline {
+; CHECK-LABEL: @recurse(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @recurse() [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @recurse() [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_recursive_call() convergent {
+; CHECK-LABEL: @test_recursive_call(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @recurse() [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @recurse() [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define i32 @outer_g(i32 %x) convergent alwaysinline {
+; CHECK-LABEL: @outer_g(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    [[Y:%.*]] = call i32 @g(i32 [[X:%.*]]) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret i32 [[Y]]
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  %y = call i32 @g(i32 %x) [ "convergencectrl"(token %token) ]
+  ret i32 %y
+}
+
+define void @test_two_calls() convergent {
+; CHECK-LABEL: @test_two_calls(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    [[Y_I:%.*]] = call i32 @g(i32 23) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 [[Y_I]]) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  %x = call i32 @outer_g(i32 23) [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 %x) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+declare void @f(i32) convergent
+declare i32 @g(i32) convergent
+
+declare token @llvm.experimental.convergence.entry()
+declare token @llvm.experimental.convergence.anchor()
+declare token @llvm.experimental.convergence.loop()


        


More information about the llvm-commits mailing list