[llvm] [Coro][WebAssembly] Add tail-call check for async lowering (PR #81481)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 12 05:44:47 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Yuta Saito (kateinoigakukun)
<details>
<summary>Changes</summary>
This patch fixes a verifier error that occurs when async lowering is used for a WebAssembly target without the tail-call feature. The missing check was exposed by b1ac052ab07ea091c90c2b7c89445b2bfcfa42ab, which removed inlining of the musttail'ed call, so the invalid call now survives until the verification stage. Additionally, `TTI::supportsTailCallFor` did not respect the concrete TTI's `supportsTailCalls` implementation: it always returned true even when `supportsTailCalls` returned false. This patch therefore also fixes the incorrect CRTP base class implementation.
---
Full diff: https://github.com/llvm/llvm-project/pull/81481.diff
5 Files Affected:
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+4-4)
- (modified) llvm/lib/Transforms/Coroutines/CoroFrame.cpp (+2-2)
- (modified) llvm/lib/Transforms/Coroutines/CoroInternal.h (+3-1)
- (modified) llvm/lib/Transforms/Coroutines/CoroSplit.cpp (+10-5)
- (added) llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll (+36)
``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d5db96e86b804..13379cc126a40c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -367,10 +367,6 @@ class TargetTransformInfoImplBase {
bool supportsTailCalls() const { return true; }
- bool supportsTailCallFor(const CallBase *CB) const {
- return supportsTailCalls();
- }
-
bool enableAggressiveInterleaving(bool LoopHasReductions) const {
return false;
}
@@ -1427,6 +1423,10 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
I, Ops, TargetTransformInfo::TCK_SizeAndLatency);
return Cost >= TargetTransformInfo::TCC_Expensive;
}
+
+ bool supportsTailCallFor(const CallBase *CB) const {
+ return static_cast<const T *>(this)->supportsTailCalls();
+ }
};
} // namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index e69c718f0ae3ac..994871eb126884 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -3064,7 +3064,7 @@ static void doRematerializations(
}
void coro::buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
@@ -3100,7 +3100,7 @@ void coro::buildCoroutineFrame(
SmallVector<Value *, 8> Args(AsyncEnd->args());
auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
- Arguments, Builder);
+ TTI, Arguments, Builder);
splitAround(Call, "MustTailCall.Before.CoroEnd");
}
}
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index fb16a4090689b4..388cf8d2aee71c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -12,6 +12,7 @@
#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
#include "CoroInstr.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IRBuilder.h"
namespace llvm {
@@ -272,9 +273,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
bool defaultMaterializable(Instruction &V);
void buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
} // End namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index aed4cd027d0338..47367d0b84edec 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -1746,6 +1746,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
}
CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments,
IRBuilder<> &Builder) {
auto *FnTy = MustTailCallFn->getFunctionType();
@@ -1755,14 +1756,18 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
coerceArguments(Builder, FnTy, Arguments, CallArgs);
auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
- TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ // Skip targets which don't support tail call.
+ if (TTI.supportsTailCallFor(TailCall)) {
+ TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ }
TailCall->setDebugLoc(Loc);
TailCall->setCallingConv(MustTailCallFn->getCallingConv());
return TailCall;
}
static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
- SmallVectorImpl<Function *> &Clones) {
+ SmallVectorImpl<Function *> &Clones,
+ TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Async);
assert(Clones.empty());
// Reset various things that the optimizer might have decided it
@@ -1837,7 +1842,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVector<Value *, 8> Args(Suspend->args());
auto FnArgs = ArrayRef<Value *>(Args).drop_front(
CoroSuspendAsyncInst::MustTailCallFuncArg + 1);
- coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder);
+ coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI, FnArgs, Builder);
Builder.CreateRetVoid();
// Replace the lvm.coro.async.resume intrisic call.
@@ -2010,7 +2015,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
return Shape;
simplifySuspendPoints(Shape);
- buildCoroutineFrame(F, Shape, MaterializableCallback);
+ buildCoroutineFrame(F, Shape, TTI, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);
// If there are no suspend points, no split required, just remove
@@ -2023,7 +2028,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
- splitAsyncCoroutine(F, Shape, Clones);
+ splitAsyncCoroutine(F, Shape, Clones, TTI);
break;
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
diff --git a/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
new file mode 100644
index 00000000000000..a64c6b34243ec3
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
@@ -0,0 +1,36 @@
+; Tests that coro-split does not mark the resume call as musttail when the
+; target (wasm without the tail-call feature) does not support tail calls.
+; RUN: opt < %s -O0 -S -mtriple=wasm32-unknown-unknown | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
+target triple = "wasm32-unknown-wasi"
+
+%swift.async_func_pointer = type <{ i32, i32 }>
+
+@checkTu = global %swift.async_func_pointer <{ i32 ptrtoint (ptr @check to i32), i32 8 }>
+
+define swiftcc void @check(ptr %0) {
+entry:
+ %1 = call token @llvm.coro.id.async(i32 0, i32 0, i32 0, ptr @checkTu)
+ %2 = call ptr @llvm.coro.begin(token %1, ptr null)
+ %3 = call ptr @llvm.coro.async.resume()
+ store ptr %3, ptr %0, align 4
+ %4 = call { ptr, i32 } (i32, ptr, ptr, ...) @llvm.coro.suspend.async.sl_p0i32s(i32 0, ptr %3, ptr @__swift_async_resume_project_context, ptr @check.0, ptr null, ptr null)
+ ret void
+}
+
+declare swiftcc void @check.0()
+declare { ptr, i32 } @llvm.coro.suspend.async.sl_p0i32s(i32, ptr, ptr, ...)
+declare token @llvm.coro.id.async(i32, i32, i32, ptr)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare ptr @llvm.coro.async.resume()
+
+define ptr @__swift_async_resume_project_context(ptr %0) {
+entry:
+ ret ptr null
+}
+
+; Verify that the resume call is not marked as musttail.
+; CHECK-LABEL: define swiftcc void @check(
+; CHECK-NOT: musttail call swiftcc void @check.0()
+; CHECK: call swiftcc void @check.0()
``````````
</details>
https://github.com/llvm/llvm-project/pull/81481
More information about the llvm-commits
mailing list