[llvm] 8c5c4d9 - [Coro][WebAssembly] Add tail-call check for async lowering (#81481)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 19 18:58:48 PST 2024
Author: Yuta Saito
Date: 2024-02-20T11:58:44+09:00
New Revision: 8c5c4d9a63bd0cba7f025431e06f63d032010393
URL: https://github.com/llvm/llvm-project/commit/8c5c4d9a63bd0cba7f025431e06f63d032010393
DIFF: https://github.com/llvm/llvm-project/commit/8c5c4d9a63bd0cba7f025431e06f63d032010393.diff
LOG: [Coro][WebAssembly] Add tail-call check for async lowering (#81481)
This patch fixes a verifier error when async lowering is used for a
WebAssembly target without the tail-call feature. The missing check was
revealed by b1ac052ab07ea091c90c2b7c89445b2bfcfa42ab, which removed
inlining of the musttail'ed call; as a result, the invalid musttail call
now survived until the verification stage. Additionally,
`TTI::supportsTailCallFor` did not respect the concrete TTI's
`supportsTailCalls` implementation: it always returned true even when
`supportsTailCalls` returned false. This patch therefore also fixes the
incorrect CRTP base class implementation.
Added:
llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
Modified:
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/lib/Transforms/Coroutines/CoroFrame.cpp
llvm/lib/Transforms/Coroutines/CoroInternal.h
llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d5db96e86b804..13379cc126a40c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -367,10 +367,6 @@ class TargetTransformInfoImplBase {
bool supportsTailCalls() const { return true; }
- bool supportsTailCallFor(const CallBase *CB) const {
- return supportsTailCalls();
- }
-
bool enableAggressiveInterleaving(bool LoopHasReductions) const {
return false;
}
@@ -1427,6 +1423,10 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
I, Ops, TargetTransformInfo::TCK_SizeAndLatency);
return Cost >= TargetTransformInfo::TCC_Expensive;
}
+
+ bool supportsTailCallFor(const CallBase *CB) const {
+ return static_cast<const T *>(this)->supportsTailCalls();
+ }
};
} // namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index e69c718f0ae3ac..994871eb126884 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -3064,7 +3064,7 @@ static void doRematerializations(
}
void coro::buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
@@ -3100,7 +3100,7 @@ void coro::buildCoroutineFrame(
SmallVector<Value *, 8> Args(AsyncEnd->args());
auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
- Arguments, Builder);
+ TTI, Arguments, Builder);
splitAround(Call, "MustTailCall.Before.CoroEnd");
}
}
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index fb16a4090689b4..388cf8d2aee71c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -12,6 +12,7 @@
#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
#include "CoroInstr.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IRBuilder.h"
namespace llvm {
@@ -272,9 +273,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
bool defaultMaterializable(Instruction &V);
void buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
} // End namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index aed4cd027d0338..47367d0b84edec 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -1746,6 +1746,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
}
CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments,
IRBuilder<> &Builder) {
auto *FnTy = MustTailCallFn->getFunctionType();
@@ -1755,14 +1756,18 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
coerceArguments(Builder, FnTy, Arguments, CallArgs);
auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
- TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ // Skip targets which don't support tail call.
+ if (TTI.supportsTailCallFor(TailCall)) {
+ TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ }
TailCall->setDebugLoc(Loc);
TailCall->setCallingConv(MustTailCallFn->getCallingConv());
return TailCall;
}
static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
- SmallVectorImpl<Function *> &Clones) {
+ SmallVectorImpl<Function *> &Clones,
+ TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Async);
assert(Clones.empty());
// Reset various things that the optimizer might have decided it
@@ -1837,7 +1842,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVector<Value *, 8> Args(Suspend->args());
auto FnArgs = ArrayRef<Value *>(Args).drop_front(
CoroSuspendAsyncInst::MustTailCallFuncArg + 1);
- coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder);
+ coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI, FnArgs, Builder);
Builder.CreateRetVoid();
// Replace the lvm.coro.async.resume intrisic call.
@@ -2010,7 +2015,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
return Shape;
simplifySuspendPoints(Shape);
- buildCoroutineFrame(F, Shape, MaterializableCallback);
+ buildCoroutineFrame(F, Shape, TTI, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);
// If there are no suspend points, no split required, just remove
@@ -2023,7 +2028,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
- splitAsyncCoroutine(F, Shape, Clones);
+ splitAsyncCoroutine(F, Shape, Clones, TTI);
break;
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
diff --git a/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
new file mode 100644
index 00000000000000..36210b6c87ef66
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
@@ -0,0 +1,31 @@
+; RUN: opt < %s -O0 -S -mtriple=wasm32-unknown-unknown | FileCheck %s
+; REQUIRES: webassembly-registered-target
+
+%swift.async_func_pointer = type <{ i32, i32 }>
+@checkTu = global %swift.async_func_pointer <{ i32 ptrtoint (ptr @check to i32), i32 8 }>
+
+define swiftcc void @check(ptr %0) {
+entry:
+ %1 = call token @llvm.coro.id.async(i32 0, i32 0, i32 0, ptr @checkTu)
+ %2 = call ptr @llvm.coro.begin(token %1, ptr null)
+ %3 = call ptr @llvm.coro.async.resume()
+ store ptr %3, ptr %0, align 4
+ %4 = call { ptr, i32 } (i32, ptr, ptr, ...) @llvm.coro.suspend.async.sl_p0i32s(i32 0, ptr %3, ptr @__swift_async_resume_project_context, ptr @check.0, ptr null, ptr null)
+ ret void
+}
+
+declare swiftcc void @check.0()
+declare { ptr, i32 } @llvm.coro.suspend.async.sl_p0i32s(i32, ptr, ptr, ...)
+declare token @llvm.coro.id.async(i32, i32, i32, ptr)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare ptr @llvm.coro.async.resume()
+
+define ptr @__swift_async_resume_project_context(ptr %0) {
+entry:
+ ret ptr null
+}
+
+; Verify that the resume call is not marked as musttail.
+; CHECK-LABEL: define swiftcc void @check(
+; CHECK-NOT: musttail call swiftcc void @check.0()
+; CHECK: call swiftcc void @check.0()
More information about the llvm-commits
mailing list