[llvm] 918e905 - [WebAssembly] Make stack pointer args inhibit tail calls

Thomas Lively via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 13 16:51:06 PST 2020


Author: Thomas Lively
Date: 2020-02-13T16:43:53-08:00
New Revision: 918e90559b08adebff26c342080c65e79cc223ec

URL: https://github.com/llvm/llvm-project/commit/918e90559b08adebff26c342080c65e79cc223ec
DIFF: https://github.com/llvm/llvm-project/commit/918e90559b08adebff26c342080c65e79cc223ec.diff

LOG: [WebAssembly] Make stack pointer args inhibit tail calls

Summary:
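Make call arguments that point into the caller's stack frame (allocas,
possibly reached through pointer casts, aliases, or GEPs) inhibit tail
calls, since the callee would otherwise receive a pointer into a frame
that no longer exists. Pointers converted to integers (e.g. via ptrtoint)
are not traced, so such calls may still be tail called. Also make return
calls terminator instructions so epilogues are inserted before them
rather than after them. Together, these changes make WebAssembly's tail
call optimization more stack-safe.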

Reviewers: aheejin, dschuff

Subscribers: sbc100, jgravelle-google, hiraditya, sunfish, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73943

Added: 
    

Modified: 
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
    llvm/test/CodeGen/WebAssembly/tailcall.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index e71290608ce9..4d45fb95af20 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -648,32 +648,51 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
     fail(DL, DAG, "WebAssembly doesn't support patch point yet");
 
   if (CLI.IsTailCall) {
-    bool MustTail = CLI.CS && CLI.CS.isMustTailCall();
-    if (Subtarget->hasTailCall() && !CLI.IsVarArg) {
-      // Do not tail call unless caller and callee return types match
-      const Function &F = MF.getFunction();
-      const TargetMachine &TM = getTargetMachine();
-      Type *RetTy = F.getReturnType();
-      SmallVector<MVT, 4> CallerRetTys;
-      SmallVector<MVT, 4> CalleeRetTys;
-      computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
-      computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
-      bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
-                        std::equal(CallerRetTys.begin(), CallerRetTys.end(),
-                                   CalleeRetTys.begin());
-      if (!TypesMatch) {
-        // musttail in this case would be an LLVM IR validation failure
-        assert(!MustTail);
-        CLI.IsTailCall = false;
-      }
-    } else {
+    auto NoTail = [&](const char *Msg) {
+      if (CLI.CS && CLI.CS.isMustTailCall())
+        fail(DL, DAG, Msg);
       CLI.IsTailCall = false;
-      if (MustTail) {
-        if (CLI.IsVarArg) {
-          // The return would pop the argument buffer
-          fail(DL, DAG, "WebAssembly does not support varargs tail calls");
-        } else {
-          fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
+    };
+
+    if (!Subtarget->hasTailCall())
+      NoTail("WebAssembly 'tail-call' feature not enabled");
+
+    // Varargs calls cannot be tail calls because the buffer is on the stack
+    if (CLI.IsVarArg)
+      NoTail("WebAssembly does not support varargs tail calls");
+
+    // Do not tail call unless caller and callee return types match
+    const Function &F = MF.getFunction();
+    const TargetMachine &TM = getTargetMachine();
+    Type *RetTy = F.getReturnType();
+    SmallVector<MVT, 4> CallerRetTys;
+    SmallVector<MVT, 4> CalleeRetTys;
+    computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+    computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
+    bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
+                      std::equal(CallerRetTys.begin(), CallerRetTys.end(),
+                                 CalleeRetTys.begin());
+    if (!TypesMatch)
+      NoTail("WebAssembly tail call requires caller and callee return types to "
+             "match");
+
+    // If pointers to local stack values are passed, we cannot tail call
+    if (CLI.CS) {
+      for (auto &Arg : CLI.CS.args()) {
+        Value *Val = Arg.get();
+        // Trace the value back through pointer operations
+        while (true) {
+          Value *Src = Val->stripPointerCastsAndAliases();
+          if (auto *GEP = dyn_cast<GetElementPtrInst>(Src))
+            Src = GEP->getPointerOperand();
+          if (Val == Src)
+            break;
+          Val = Src;
+        }
+        if (isa<AllocaInst>(Val)) {
+          NoTail(
+              "WebAssembly does not support tail calling with stack arguments");
+          break;
         }
       }
     }

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
index 703c15d58c93..20b74c6d72d2 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
@@ -74,7 +74,7 @@ defm CALL_VOID :
     [(WebAssemblycall0 (i32 imm:$callee))],
     "call    \t$callee", "call\t$callee", 0x10>;
 
-let isReturn = 1 in
+let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in
 defm RET_CALL :
   I<(outs), (ins function32_op:$callee, variable_ops),
     (outs), (ins function32_op:$callee),

diff  --git a/llvm/test/CodeGen/WebAssembly/tailcall.ll b/llvm/test/CodeGen/WebAssembly/tailcall.ll
index f4d4499bcef7..96bd9a67569e 100644
--- a/llvm/test/CodeGen/WebAssembly/tailcall.ll
+++ b/llvm/test/CodeGen/WebAssembly/tailcall.ll
@@ -209,7 +209,37 @@ define i1 @mismatched_return_trunc() {
   ret i1 %u
 }
 
+; Stack-allocated arguments inhibit tail calls
 
+; CHECK-LABEL: stack_arg:
+; CHECK: i32.call
+define i32 @stack_arg(i32* %x) {
+  %a = alloca i32
+  %v = tail call i32 @stack_arg(i32* %a)
+  ret i32 %v
+}
+
+; CHECK-LABEL: stack_arg_gep:
+; CHECK: i32.call
+define i32 @stack_arg_gep(i32* %x) {
+  %a = alloca { i32, i32 }
+  %p = getelementptr { i32, i32 }, { i32, i32 }* %a, i32 0, i32 1
+  %v = tail call i32 @stack_arg_gep(i32* %p)
+  ret i32 %v
+}
+
+; CHECK-LABEL: stack_arg_cast:
+; CHECK: global.get $push{{[0-9]+}}=, __stack_pointer
+; CHECK: global.set __stack_pointer, $pop{{[0-9]+}}
+; FAST: i32.call ${{[0-9]+}}=, stack_arg_cast, $pop{{[0-9]+}}
+; CHECK: global.set __stack_pointer, $pop{{[0-9]+}}
+; SLOW: return_call stack_arg_cast, ${{[0-9]+}}
+define i32 @stack_arg_cast(i32 %x) {
+  %a = alloca [64 x i32]
+  %i = ptrtoint [64 x i32]* %a to i32
+  %v = tail call i32 @stack_arg_cast(i32 %i)
+  ret i32 %v
+}
 
 ; Check that the signatures generated for external indirectly
 ; return-called functions include the proper return types


        


More information about the llvm-commits mailing list