[llvm] [Coro][WebAssembly] Add tail-call check for async lowering (PR #81481)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 05:44:47 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Yuta Saito (kateinoigakukun)

<details>
<summary>Changes</summary>

This patch fixes a verifier error when async lowering is used for a WebAssembly target without the tail-call feature. The missing check was revealed by b1ac052ab07ea091c90c2b7c89445b2bfcfa42ab, which removed inlining of the musttail'ed call, so the invalid call is now left in place at the verification stage. Additionally, `TTI::supportsTailCallFor` did not respect the concrete TTI's `supportsTailCalls` implementation: it always returned true even when `supportsTailCalls` returned false. This patch therefore also fixes the incorrect CRTP base class implementation.

---
Full diff: https://github.com/llvm/llvm-project/pull/81481.diff


5 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+4-4) 
- (modified) llvm/lib/Transforms/Coroutines/CoroFrame.cpp (+2-2) 
- (modified) llvm/lib/Transforms/Coroutines/CoroInternal.h (+3-1) 
- (modified) llvm/lib/Transforms/Coroutines/CoroSplit.cpp (+10-5) 
- (added) llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll (+36) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d5db96e86b804..13379cc126a40c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -367,10 +367,6 @@ class TargetTransformInfoImplBase {
 
   bool supportsTailCalls() const { return true; }
 
-  bool supportsTailCallFor(const CallBase *CB) const {
-    return supportsTailCalls();
-  }
-
   bool enableAggressiveInterleaving(bool LoopHasReductions) const {
     return false;
   }
@@ -1427,6 +1423,10 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
         I, Ops, TargetTransformInfo::TCK_SizeAndLatency);
     return Cost >= TargetTransformInfo::TCC_Expensive;
   }
+
+  bool supportsTailCallFor(const CallBase *CB) const {
+    return static_cast<const T *>(this)->supportsTailCalls();
+  }
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index e69c718f0ae3ac..994871eb126884 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -3064,7 +3064,7 @@ static void doRematerializations(
 }
 
 void coro::buildCoroutineFrame(
-    Function &F, Shape &Shape,
+    Function &F, Shape &Shape, TargetTransformInfo &TTI,
     const std::function<bool(Instruction &)> &MaterializableCallback) {
   // Don't eliminate swifterror in async functions that won't be split.
   if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
@@ -3100,7 +3100,7 @@ void coro::buildCoroutineFrame(
       SmallVector<Value *, 8> Args(AsyncEnd->args());
       auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
       auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
-                                      Arguments, Builder);
+                                      TTI, Arguments, Builder);
       splitAround(Call, "MustTailCall.Before.CoroEnd");
     }
   }
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index fb16a4090689b4..388cf8d2aee71c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -12,6 +12,7 @@
 #define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
 
 #include "CoroInstr.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/IRBuilder.h"
 
 namespace llvm {
@@ -272,9 +273,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
 
 bool defaultMaterializable(Instruction &V);
 void buildCoroutineFrame(
-    Function &F, Shape &Shape,
+    Function &F, Shape &Shape, TargetTransformInfo &TTI,
     const std::function<bool(Instruction &)> &MaterializableCallback);
 CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+                             TargetTransformInfo &TTI,
                              ArrayRef<Value *> Arguments, IRBuilder<> &);
 } // End namespace coro.
 } // End namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index aed4cd027d0338..47367d0b84edec 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -1746,6 +1746,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
 }
 
 CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+                                   TargetTransformInfo &TTI,
                                    ArrayRef<Value *> Arguments,
                                    IRBuilder<> &Builder) {
   auto *FnTy = MustTailCallFn->getFunctionType();
@@ -1755,14 +1756,18 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
   coerceArguments(Builder, FnTy, Arguments, CallArgs);
 
   auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
-  TailCall->setTailCallKind(CallInst::TCK_MustTail);
+  // Skip targets which don't support tail call.
+  if (TTI.supportsTailCallFor(TailCall)) {
+    TailCall->setTailCallKind(CallInst::TCK_MustTail);
+  }
   TailCall->setDebugLoc(Loc);
   TailCall->setCallingConv(MustTailCallFn->getCallingConv());
   return TailCall;
 }
 
 static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
-                                SmallVectorImpl<Function *> &Clones) {
+                                SmallVectorImpl<Function *> &Clones,
+                                TargetTransformInfo &TTI) {
   assert(Shape.ABI == coro::ABI::Async);
   assert(Clones.empty());
   // Reset various things that the optimizer might have decided it
@@ -1837,7 +1842,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
     SmallVector<Value *, 8> Args(Suspend->args());
     auto FnArgs = ArrayRef<Value *>(Args).drop_front(
         CoroSuspendAsyncInst::MustTailCallFuncArg + 1);
-    coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder);
+    coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI, FnArgs, Builder);
     Builder.CreateRetVoid();
 
     // Replace the lvm.coro.async.resume intrisic call.
@@ -2010,7 +2015,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
     return Shape;
 
   simplifySuspendPoints(Shape);
-  buildCoroutineFrame(F, Shape, MaterializableCallback);
+  buildCoroutineFrame(F, Shape, TTI, MaterializableCallback);
   replaceFrameSizeAndAlignment(Shape);
 
   // If there are no suspend points, no split required, just remove
@@ -2023,7 +2028,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
       SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
       break;
     case coro::ABI::Async:
-      splitAsyncCoroutine(F, Shape, Clones);
+      splitAsyncCoroutine(F, Shape, Clones, TTI);
       break;
     case coro::ABI::Retcon:
     case coro::ABI::RetconOnce:
diff --git a/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
new file mode 100644
index 00000000000000..a64c6b34243ec3
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
@@ -0,0 +1,36 @@
+; Tests that async lowering does not mark the call as musttail on a
+; WebAssembly target without the tail-call feature.
+; RUN: opt < %s -O0 -S -mtriple=wasm32-unknown-unknown | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
+target triple = "wasm32-unknown-wasi"
+
+%swift.async_func_pointer = type <{ i32, i32 }>
+
+@checkTu = global %swift.async_func_pointer <{ i32 ptrtoint (ptr @check to i32), i32 8 }>
+
+define swiftcc void @check(ptr %0) {
+entry:
+  %1 = call token @llvm.coro.id.async(i32 0, i32 0, i32 0, ptr @checkTu)
+  %2 = call ptr @llvm.coro.begin(token %1, ptr null)
+  %3 = call ptr @llvm.coro.async.resume()
+  store ptr %3, ptr %0, align 4
+  %4 = call { ptr, i32 } (i32, ptr, ptr, ...) @llvm.coro.suspend.async.sl_p0i32s(i32 0, ptr %3, ptr @__swift_async_resume_project_context, ptr @check.0, ptr null, ptr null)
+  ret void
+}
+
+declare swiftcc void @check.0()
+declare { ptr, i32 } @llvm.coro.suspend.async.sl_p0i32s(i32, ptr, ptr, ...)
+declare token @llvm.coro.id.async(i32, i32, i32, ptr)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare ptr @llvm.coro.async.resume()
+
+define ptr @__swift_async_resume_project_context(ptr %0) {
+entry:
+  ret ptr null
+}
+
+; Verify that the resume call is not marked as musttail.
+; CHECK-LABEL: define swiftcc void @check(
+; CHECK-NOT: musttail call swiftcc void @check.0()
+; CHECK:     call swiftcc void @check.0()

``````````

</details>


https://github.com/llvm/llvm-project/pull/81481


More information about the llvm-commits mailing list