[llvm-branch-commits] [llvm] [LLVM][Coroutines] Transform "coro_must_elide" calls to switch ABI coroutines to the `noalloc` variant (PR #99285)

Yuxuan Chen via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 22 10:34:11 PDT 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/99285

>From 17f5922762c0f5203cccda6f3d09ac1b2623078f Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Mon, 15 Jul 2024 15:01:39 -0700
Subject: [PATCH] add CoroAnnotationElidePass

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: https://phabricator.intern.facebook.com/D60250514
---
 .../Coroutines/CoroAnnotationElide.h          |  36 +++++
 llvm/lib/Passes/PassBuilder.cpp               |   1 +
 llvm/lib/Passes/PassBuilderPipelines.cpp      |  10 +-
 llvm/lib/Passes/PassRegistry.def              |   1 +
 llvm/lib/Transforms/Coroutines/CMakeLists.txt |   1 +
 .../Coroutines/CoroAnnotationElide.cpp        | 141 ++++++++++++++++++
 llvm/test/Other/new-pm-defaults.ll            |   1 +
 .../Other/new-pm-thinlto-postlink-defaults.ll |   1 +
 .../new-pm-thinlto-postlink-pgo-defaults.ll   |   1 +
 ...-pm-thinlto-postlink-samplepgo-defaults.ll |   1 +
 .../Coroutines/coro-transform-must-elide.ll   |  76 ++++++++++
 11 files changed, 268 insertions(+), 2 deletions(-)
 create mode 100644 llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
 create mode 100644 llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
 create mode 100644 llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll

diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
new file mode 100644
index 00000000000000..2b8885c074635d
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
@@ -0,0 +1,36 @@
+//===- CoroAnnotationElide.h - Optimizing a coro_elide_safe call ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This pass transforms all Call or Invoke instructions that are annotated
+// "coro_elide_safe" to call the `.noalloc` variant of the coroutine instead.
+// The frame of the callee coroutine is allocated inside the caller. A pointer
+// to the allocated frame will be passed into the `.noalloc` ramp function.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
+#define LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
+
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+struct CoroAnnotationElidePass : PassInfoMixin<CoroAnnotationElidePass> {
+  CoroAnnotationElidePass() {}
+
+  PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
+                        LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
+  static bool isRequired() { return false; }
+};
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 17eed97fd950c9..c2b99a0d1f8cea 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -138,6 +138,7 @@
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
 #include "llvm/Transforms/CFGuard.h"
+#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
 #include "llvm/Transforms/Coroutines/CoroCleanup.h"
 #include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
 #include "llvm/Transforms/Coroutines/CoroEarly.h"
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 1184123c7710f0..992b4fca8a6919 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -33,6 +33,7 @@
 #include "llvm/Support/VirtualFileSystem.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
+#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
 #include "llvm/Transforms/Coroutines/CoroCleanup.h"
 #include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
 #include "llvm/Transforms/Coroutines/CoroEarly.h"
@@ -984,8 +985,10 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
   MainCGPipeline.addPass(createCGSCCToFunctionPassAdaptor(
       RequireAnalysisPass<ShouldNotRunFunctionPassesAnalysis, Function>()));
 
-  if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
+  if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
     MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0));
+    MainCGPipeline.addPass(CoroAnnotationElidePass());
+  }
 
   // Make sure we don't affect potential future NoRerun CGSCC adaptors.
   MIWP.addLateModulePass(createModuleToFunctionPassAdaptor(
@@ -1027,9 +1030,12 @@ PassBuilder::buildModuleInlinerPipeline(OptimizationLevel Level,
       buildFunctionSimplificationPipeline(Level, Phase),
       PTO.EagerlyInvalidateAnalyses));
 
-  if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
+  if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
     MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
         CoroSplitPass(Level != OptimizationLevel::O0)));
+    MPM.addPass(
+        createModuleToPostOrderCGSCCPassAdaptor(CoroAnnotationElidePass()));
+  }
 
   return MPM;
 }
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index d6067089c6b5c1..26550e4f00c64d 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -243,6 +243,7 @@ CGSCC_PASS("attributor-light-cgscc", AttributorLightCGSCCPass())
 CGSCC_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 CGSCC_PASS("no-op-cgscc", NoOpCGSCCPass())
 CGSCC_PASS("openmp-opt-cgscc", OpenMPOptCGSCCPass())
+CGSCC_PASS("coro-annotation-elide", CoroAnnotationElidePass())
 #undef CGSCC_PASS
 
 #ifndef CGSCC_PASS_WITH_PARAMS
diff --git a/llvm/lib/Transforms/Coroutines/CMakeLists.txt b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
index 2139446e5ff957..b4b5812d97d893 100644
--- a/llvm/lib/Transforms/Coroutines/CMakeLists.txt
+++ b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_component_library(LLVMCoroutines
   Coroutines.cpp
+  CoroAnnotationElide.cpp
   CoroCleanup.cpp
   CoroConditionalWrapper.cpp
   CoroEarly.cpp
diff --git a/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
new file mode 100644
index 00000000000000..8bbe9b70e22695
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
@@ -0,0 +1,141 @@
+//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
+
+#include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/Utils/CallGraphUpdater.h"
+
+#include <cassert>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "coro-annotation-elide"
+
+static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
+  for (Instruction &I : F->getEntryBlock())
+    if (!isa<AllocaInst>(&I))
+      return &I;
+  llvm_unreachable("no terminator in the entry block");
+}
+
+// Create an alloca in the caller, sized by FrameSize and aligned to FrameAlign,
+// to serve as the callee coroutine's activation frame.
+static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
+                                    Align FrameAlign) {
+  LLVMContext &C = Caller->getContext();
+  BasicBlock::iterator InsertPt =
+      getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
+  const DataLayout &DL = Caller->getDataLayout();
+  auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
+  auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
+  Frame->setAlignment(FrameAlign);
+  return new BitCastInst(Frame, PointerType::getUnqual(C), "vFrame", InsertPt);
+}
+
+// Given a call or invoke instruction to the elide safe coroutine, this function
+// does the following:
+//  - Allocate a frame for the callee coroutine in the caller using alloca.
+//  - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
+//    pointer to the frame as an additional argument to NewCallee.
+static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
+                        uint64_t FrameSize, Align FrameAlign) {
+  auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
+  auto NewCBInsertPt = CB->getIterator();
+  llvm::CallBase *NewCB = nullptr;
+  SmallVector<Value *, 4> NewArgs;
+  NewArgs.append(CB->arg_begin(), CB->arg_end());
+  NewArgs.push_back(FramePtr);
+
+  if (auto *CI = dyn_cast<CallInst>(CB)) {
+    auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
+                                   NewArgs, "", NewCBInsertPt);
+    NewCI->setTailCallKind(CI->getTailCallKind());
+    NewCB = NewCI;
+  } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
+    NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
+                               II->getNormalDest(), II->getUnwindDest(),
+                               NewArgs, std::nullopt, "", NewCBInsertPt);
+  } else {
+    llvm_unreachable("CallBase should either be Call or Invoke!");
+  }
+
+  NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
+  NewCB->setCallingConv(CB->getCallingConv());
+  NewCB->setAttributes(CB->getAttributes());
+  NewCB->setDebugLoc(CB->getDebugLoc());
+  std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
+            NewCB->bundle_op_info_begin());
+
+  NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
+  CB->replaceAllUsesWith(NewCB);
+  CB->eraseFromParent();
+}
+
+PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
+                                               CGSCCAnalysisManager &AM,
+                                               LazyCallGraph &CG,
+                                               CGSCCUpdateResult &UR) {
+  bool Changed = false;
+  CallGraphUpdater CGUpdater;
+  CGUpdater.initialize(CG, C, AM, UR);
+
+  auto &FAM =
+      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
+
+  for (LazyCallGraph::Node &N : C) {
+    Function *Callee = &N.getFunction();
+    Function *NewCallee = Callee->getParent()->getFunction(
+        (Callee->getName() + ".noalloc").str());
+    if (!NewCallee) {
+      continue;
+    }
+
+    auto FramePtrArgPosition = NewCallee->arg_size() - 1;
+    auto FrameSize =
+        NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
+    auto FrameAlign =
+        NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
+
+    SmallVector<CallBase *, 4> Users;
+    for (auto *U : Callee->users()) {
+      if (auto *CB = dyn_cast<CallBase>(U)) {
+        if (CB->getCalledFunction() == Callee)
+          Users.push_back(CB);
+      }
+    }
+
+    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);
+
+    for (auto *CB : Users) {
+      auto *Caller = CB->getFunction();
+      if (Caller && Caller->isPresplitCoroutine() &&
+          CB->hasFnAttr(llvm::Attribute::CoroElideSafe)) {
+        processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
+        CGUpdater.reanalyzeFunction(*Caller);
+
+        ORE.emit([&]() {
+          return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
+                 << "'" << ore::NV("callee", Callee->getName())
+                 << "' elided in '" << ore::NV("caller", Caller->getName());
+        });
+        Changed = true;
+      }
+    }
+  }
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
diff --git a/llvm/test/Other/new-pm-defaults.ll b/llvm/test/Other/new-pm-defaults.ll
index 588337c15625e6..55dbdb1b8366d6 100644
--- a/llvm/test/Other/new-pm-defaults.ll
+++ b/llvm/test/Other/new-pm-defaults.ll
@@ -226,6 +226,7 @@
 ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
 ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll
index 064362eabbf839..fcf84dc5e11051 100644
--- a/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll
+++ b/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll
@@ -153,6 +153,7 @@
 ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
 ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll
index 19a44867e434ac..4d5b5e733a87c2 100644
--- a/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll
+++ b/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll
@@ -137,6 +137,7 @@
 ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
 ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll
index e5aebc4850e6db..47f17996011d10 100644
--- a/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll
+++ b/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll
@@ -145,6 +145,7 @@
 ; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
 ; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
 ; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll b/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll
new file mode 100644
index 00000000000000..3bb5e5874bb019
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll
@@ -0,0 +1,76 @@
+; Test that calls annotated coro_elide_safe are rewritten to the .noalloc variant.
+; RUN: opt < %s -S -passes='cgscc(coro-annotation-elide)' | FileCheck %s
+
+%struct.Task = type { ptr }
+
+declare void @print(i32) nounwind
+
+; resume part of the coroutine
+define fastcc void @callee.resume(ptr dereferenceable(1)) {
+  tail call void @print(i32 0)
+  ret void
+}
+
+; destroy part of the coroutine
+define fastcc void @callee.destroy(ptr) {
+  tail call void @print(i32 1)
+  ret void
+}
+
+; cleanup part of the coroutine
+define fastcc void @callee.cleanup(ptr) {
+  tail call void @print(i32 2)
+  ret void
+}
+
+@callee.resumers = internal constant [3 x ptr] [
+  ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup]
+
+declare void @alloc(i1) nounwind
+
+; CHECK-LABEL: define ptr @callee
+define ptr @callee(i8 %arg) {
+entry:
+  %task = alloca %struct.Task, align 8
+  %id = call token @llvm.coro.id(i32 0, ptr null,
+                          ptr @callee,
+                          ptr @callee.resumers)
+  %alloc = call i1 @llvm.coro.alloc(token %id)
+  %hdl = call ptr @llvm.coro.begin(token %id, ptr null)
+  store ptr %hdl, ptr %task
+  ret ptr %task
+}
+
+; CHECK-LABEL: define ptr @callee.noalloc
+define ptr @callee.noalloc(i8 %arg, ptr dereferenceable(32) align(8) %frame) {
+ entry:
+  %task = alloca %struct.Task, align 8
+  %id = call token @llvm.coro.id(i32 0, ptr null,
+                          ptr @callee,
+                          ptr @callee.resumers)
+  %hdl = call ptr @llvm.coro.begin(token %id, ptr null)
+  store ptr %hdl, ptr %task
+  ret ptr %task
+}
+
+; CHECK-LABEL: define ptr @caller()
+; Function Attrs: presplitcoroutine
+define ptr @caller() #0 {
+entry:
+  %task = call ptr @callee(i8 0) #1
+  ret ptr %task
+
+  ; CHECK: %[[ALLOCA:.+]] = alloca [32 x i8], align 8
+  ; CHECK-NEXT: %[[FRAME:.+]] = bitcast ptr %[[ALLOCA]] to ptr
+  ; CHECK-NEXT: %[[TASK:.+]] = call ptr @callee.noalloc(i8 0, ptr %[[FRAME]])
+  ; CHECK-NEXT: ret ptr %[[TASK]]
+}
+
+declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+declare ptr @llvm.coro.begin(token, ptr)
+declare ptr @llvm.coro.frame()
+declare ptr @llvm.coro.subfn.addr(ptr, i8)
+declare i1 @llvm.coro.alloc(token)
+
+attributes #0 = { presplitcoroutine }
+attributes #1 = { coro_elide_safe }



More information about the llvm-branch-commits mailing list