[llvm-branch-commits] [llvm] [LLVM][Coroutines] Transform "coro_elide_safe" calls to switch ABI coroutines to the `noalloc` variant (PR #99285)
Yuxuan Chen via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Sep 8 21:47:40 PDT 2024
https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/99285
>From 201bca06d4e75bc4fa24ac269ad7b9750f24616f Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Mon, 15 Jul 2024 15:01:39 -0700
Subject: [PATCH] [LLVM][Coroutines] Transform "coro_elide_safe" calls to
switch ABI coroutines to the `noalloc` variant
---
.../Coroutines/CoroAnnotationElide.h | 36 ++++
llvm/lib/Passes/PassBuilder.cpp | 1 +
llvm/lib/Passes/PassBuilderPipelines.cpp | 10 +-
llvm/lib/Passes/PassRegistry.def | 1 +
llvm/lib/Transforms/Coroutines/CMakeLists.txt | 1 +
.../Coroutines/CoroAnnotationElide.cpp | 155 ++++++++++++++++++
llvm/test/Other/new-pm-defaults.ll | 1 +
.../Other/new-pm-thinlto-postlink-defaults.ll | 1 +
.../new-pm-thinlto-postlink-pgo-defaults.ll | 1 +
...-pm-thinlto-postlink-samplepgo-defaults.ll | 1 +
.../Coroutines/coro-transform-must-elide.ll | 75 +++++++++
11 files changed, 281 insertions(+), 2 deletions(-)
create mode 100644 llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
create mode 100644 llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
create mode 100644 llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
new file mode 100644
index 00000000000000..352c9e14526697
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
@@ -0,0 +1,36 @@
+//===- CoroAnnotationElide.h - Elide attributed safe coroutine calls ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This pass transforms all Call or Invoke instructions that are annotated
+// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
+// The frame of the callee coroutine is allocated inside the caller. A pointer
+// to the allocated frame will be passed into the `.noalloc` ramp function.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
+#define LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
+
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+struct CoroAnnotationElidePass : PassInfoMixin<CoroAnnotationElidePass> {
+ CoroAnnotationElidePass() {}
+ // Runs the transform described in the \file comment on the SCC `C`.
+ PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
+ LazyCallGraph &CG, CGSCCUpdateResult &UR);
+ // NOTE(review): `false` presumably lets the pass manager skip it -- confirm.
+ static bool isRequired() { return false; }
+};
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 83c1a6712bf4d9..c34f9148cce58b 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -139,6 +139,7 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
#include "llvm/Transforms/CFGuard.h"
+#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
#include "llvm/Transforms/Coroutines/CoroCleanup.h"
#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
#include "llvm/Transforms/Coroutines/CoroEarly.h"
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 7f9e1362e7ef23..4e8e3dcdff4428 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -33,6 +33,7 @@
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
+#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
#include "llvm/Transforms/Coroutines/CoroCleanup.h"
#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
#include "llvm/Transforms/Coroutines/CoroEarly.h"
@@ -973,8 +974,10 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
MainCGPipeline.addPass(createCGSCCToFunctionPassAdaptor(
RequireAnalysisPass<ShouldNotRunFunctionPassesAnalysis, Function>()));
- if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
+ if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0));
+ MainCGPipeline.addPass(CoroAnnotationElidePass());
+ }
// Make sure we don't affect potential future NoRerun CGSCC adaptors.
MIWP.addLateModulePass(createModuleToFunctionPassAdaptor(
@@ -1016,9 +1019,12 @@ PassBuilder::buildModuleInlinerPipeline(OptimizationLevel Level,
buildFunctionSimplificationPipeline(Level, Phase),
PTO.EagerlyInvalidateAnalyses));
- if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
+ if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
CoroSplitPass(Level != OptimizationLevel::O0)));
+ MPM.addPass(
+ createModuleToPostOrderCGSCCPassAdaptor(CoroAnnotationElidePass()));
+ }
return MPM;
}
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 1769c2496a7b8e..4f5f680a6e9535 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -243,6 +243,7 @@ CGSCC_PASS("attributor-light-cgscc", AttributorLightCGSCCPass())
CGSCC_PASS("invalidate<all>", InvalidateAllAnalysesPass())
CGSCC_PASS("no-op-cgscc", NoOpCGSCCPass())
CGSCC_PASS("openmp-opt-cgscc", OpenMPOptCGSCCPass())
+CGSCC_PASS("coro-annotation-elide", CoroAnnotationElidePass())
#undef CGSCC_PASS
#ifndef CGSCC_PASS_WITH_PARAMS
diff --git a/llvm/lib/Transforms/Coroutines/CMakeLists.txt b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
index 2139446e5ff957..b4b5812d97d893 100644
--- a/llvm/lib/Transforms/Coroutines/CMakeLists.txt
+++ b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
@@ -1,5 +1,6 @@
add_llvm_component_library(LLVMCoroutines
Coroutines.cpp
+ CoroAnnotationElide.cpp
CoroCleanup.cpp
CoroConditionalWrapper.cpp
CoroEarly.cpp
diff --git a/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
new file mode 100644
index 00000000000000..27a370a5d8fbf9
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
@@ -0,0 +1,155 @@
+//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This pass transforms all Call or Invoke instructions that are annotated
+// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
+// The frame of the callee coroutine is allocated inside the caller. A pointer
+// to the allocated frame will be passed into the `.noalloc` ramp function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
+
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/Utils/CallGraphUpdater.h"
+
+#include <cassert>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "coro-annotation-elide"
+
+static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
+  for (Instruction &Inst : F->getEntryBlock()) // walk in program order
+    if (!isa<AllocaInst>(Inst))
+      return &Inst;
+  llvm_unreachable("no terminator in the entry block");
+}
+
+// Emits, before the first non-alloca instruction of the caller's entry block,
+// an alloca of FrameSize bytes aligned to FrameAlign for the callee's frame.
+static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
+                                    Align FrameAlign) {
+  // The frame is modelled as an opaque [FrameSize x i8] array.
+  auto *FrameBytesTy =
+      ArrayType::get(Type::getInt8Ty(Caller->getContext()), FrameSize);
+  unsigned AllocaAS = Caller->getDataLayout().getAllocaAddrSpace();
+  auto InsertBefore = getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
+  auto *Slot = new AllocaInst(FrameBytesTy, AllocaAS, "", InsertBefore);
+  Slot->setAlignment(FrameAlign);
+  return Slot;
+}
+
+// Given a call or invoke instruction to the elide safe coroutine, this function
+// does the following:
+// - Allocate a frame for the callee coroutine in the caller using alloca.
+// - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
+//   pointer to the frame as an additional argument to NewCallee.
+static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
+                        uint64_t FrameSize, Align FrameAlign) {
+  // TODO: generate the lifetime intrinsics for the new frame. This will require
+  // introduction of two pseudo lifetime intrinsics in the frontend around the
+  // `co_await` expression and convert them to real lifetime intrinsics here.
+  auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
+  auto NewCBInsertPt = CB->getIterator();
+  llvm::CallBase *NewCB = nullptr;
+  SmallVector<Value *, 4> NewArgs;
+  NewArgs.append(CB->arg_begin(), CB->arg_end());
+  NewArgs.push_back(FramePtr);
+  // Bundles can only be attached at creation, so collect CB's up front.
+  SmallVector<OperandBundleDef, 1> Bundles;
+  CB->getOperandBundlesAsDefs(Bundles);
+
+  if (auto *CI = dyn_cast<CallInst>(CB)) {
+    auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
+                                   NewArgs, Bundles, "", NewCBInsertPt);
+    NewCI->setTailCallKind(CI->getTailCallKind());
+    NewCB = NewCI;
+  } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
+    NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
+                               II->getNormalDest(), II->getUnwindDest(),
+                               NewArgs, Bundles, "", NewCBInsertPt);
+  } else {
+    llvm_unreachable("CallBase should either be Call or Invoke!");
+  }
+
+  NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
+  NewCB->setCallingConv(CB->getCallingConv());
+  NewCB->setAttributes(CB->getAttributes());
+  NewCB->setDebugLoc(CB->getDebugLoc());
+  NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
+  CB->replaceAllUsesWith(NewCB);
+  CB->eraseFromParent();
+}
+
+// For each function in the SCC that has a `.noalloc` variant, redirects its
+// "coro_elide_safe" calls made from presplit coroutines to that variant.
+PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
+                                               CGSCCAnalysisManager &AM,
+                                               LazyCallGraph &CG,
+                                               CGSCCUpdateResult &UR) {
+  bool Changed = false;
+  auto &FAM =
+      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
+
+  for (LazyCallGraph::Node &N : C) {
+    // Only functions with a `<name>.noalloc` variant in the module qualify.
+    Function *Callee = &N.getFunction();
+    Function *NewCallee = Callee->getParent()->getFunction(
+        (Callee->getName() + ".noalloc").str());
+    if (!NewCallee)
+      continue;
+
+    // The frame pointer is the variant's last parameter; its dereferenceable
+    // and align attributes give the frame's size and alignment.
+    auto FrameArgNo = NewCallee->arg_size() - 1;
+    uint64_t FrameSize = NewCallee->getParamDereferenceableBytes(FrameArgNo);
+    Align FrameAlign = NewCallee->getParamAlign(FrameArgNo).valueOrOne();
+
+    // Snapshot the direct calls first: processCall erases the call sites.
+    SmallVector<CallBase *, 4> Users;
+    for (auto *U : Callee->users()) {
+      auto *CB = dyn_cast<CallBase>(U);
+      if (CB && CB->getCalledFunction() == Callee)
+        Users.push_back(CB);
+    }
+
+    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);
+    for (auto *CB : Users) {
+      auto *Caller = CB->getFunction();
+      if (Caller && Caller->isPresplitCoroutine() &&
+          CB->hasFnAttr(llvm::Attribute::CoroElideSafe)) {
+        // CG.lookup() returns null if the lazy call graph has no node for Caller
+        // yet; the call is still rewritten and only the graph update is skipped.
+        auto *CallerN = CG.lookup(*Caller);
+        auto *CallerC = CallerN ? CG.lookupSCC(*CallerN) : nullptr;
+        processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
+        ORE.emit([&]() {
+          return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
+                 << "'" << ore::NV("callee", Callee->getName())
+                 << "' elided in '" << ore::NV("caller", Caller->getName())
+                 << "'";
+        });
+        Changed = true;
+        if (CallerC)
+          updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM,
+                                                 UR, FAM);
+      }
+    }
+  }
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
diff --git a/llvm/test/Other/new-pm-defaults.ll b/llvm/test/Other/new-pm-defaults.ll
index 588337c15625e6..55dbdb1b8366d6 100644
--- a/llvm/test/Other/new-pm-defaults.ll
+++ b/llvm/test/Other/new-pm-defaults.ll
@@ -226,6 +226,7 @@
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll
index 064362eabbf839..fcf84dc5e11051 100644
--- a/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll
+++ b/llvm/test/Other/new-pm-thinlto-postlink-defaults.ll
@@ -153,6 +153,7 @@
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll
index 19a44867e434ac..4d5b5e733a87c2 100644
--- a/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll
+++ b/llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll
@@ -137,6 +137,7 @@
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll b/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll
index 9c2025f7d1ec39..62b81ac7cad03f 100644
--- a/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll
+++ b/llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll
@@ -146,6 +146,7 @@
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Running pass: CoroSplitPass
+; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
diff --git a/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll b/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll
new file mode 100644
index 00000000000000..a4e575f6c03816
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-transform-must-elide.ll
@@ -0,0 +1,75 @@
+; Testing elide performed its job for calls to coroutines marked safe.
+; RUN: opt < %s -S -passes='cgscc(coro-annotation-elide)' | FileCheck %s
+
+%struct.Task = type { ptr }
+
+declare void @print(i32) nounwind
+
+; resume part of the coroutine
+define fastcc void @callee.resume(ptr dereferenceable(1)) {
+ tail call void @print(i32 0)
+ ret void
+}
+
+; destroy part of the coroutine
+define fastcc void @callee.destroy(ptr) {
+ tail call void @print(i32 1)
+ ret void
+}
+
+; cleanup part of the coroutine
+define fastcc void @callee.cleanup(ptr) {
+ tail call void @print(i32 2)
+ ret void
+}
+
+@callee.resumers = internal constant [3 x ptr] [
+ ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup]
+
+declare void @alloc(i1) nounwind
+
+; CHECK-LABEL: define ptr @callee
+define ptr @callee(i8 %arg) {
+entry:
+ %task = alloca %struct.Task, align 8
+ %id = call token @llvm.coro.id(i32 0, ptr null,
+ ptr @callee,
+ ptr @callee.resumers)
+ %alloc = call i1 @llvm.coro.alloc(token %id)
+ %hdl = call ptr @llvm.coro.begin(token %id, ptr null)
+ store ptr %hdl, ptr %task
+ ret ptr %task
+}
+
+; CHECK-LABEL: define ptr @callee.noalloc
+define ptr @callee.noalloc(i8 %arg, ptr dereferenceable(32) align(8) %frame) {
+ entry:
+ %task = alloca %struct.Task, align 8
+ %id = call token @llvm.coro.id(i32 0, ptr null,
+ ptr @callee,
+ ptr @callee.resumers)
+ %hdl = call ptr @llvm.coro.begin(token %id, ptr null)
+ store ptr %hdl, ptr %task
+ ret ptr %task
+}
+
+; CHECK-LABEL: define ptr @caller()
+; Function Attrs: presplitcoroutine
+define ptr @caller() #0 {
+entry:
+ %task = call ptr @callee(i8 0) #1
+ ret ptr %task
+
+ ; CHECK: %[[FRAME:.+]] = alloca [32 x i8], align 8
+ ; CHECK-NEXT: %[[TASK:.+]] = call ptr @callee.noalloc(i8 0, ptr %[[FRAME]])
+ ; CHECK-NEXT: ret ptr %[[TASK]]
+}
+
+declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+declare ptr @llvm.coro.begin(token, ptr)
+declare ptr @llvm.coro.frame()
+declare ptr @llvm.coro.subfn.addr(ptr, i8)
+declare i1 @llvm.coro.alloc(token)
+
+attributes #0 = { presplitcoroutine }
+attributes #1 = { coro_elide_safe }
More information about the llvm-branch-commits
mailing list