[clang] [llvm] [WebAssembly] Implement an alternative translation for -wasm-enable-sjlj (PR #84137)

YAMAMOTO Takashi via cfe-commits cfe-commits at lists.llvm.org
Wed Mar 6 00:14:31 PST 2024


https://github.com/yamt created https://github.com/llvm/llvm-project/pull/84137

Instead of maintaining per-function-invocation malloc()'ed tables to track which functions each label belongs to, store the equivalent info in jump buffers (jmp_buf) themselves.

Also, use less emscripten-looking ABI symbols:

    saveSetjmp     -> __wasm_sjlj_setjmp
    testSetjmp     -> __wasm_sjlj_test
    getTempRet0    -> (removed)
    __wasm_longjmp -> __wasm_sjlj_longjmp

Enabled with:

    -mllvm -wasm-enable-sjlj -mllvm -experimental-wasm-enable-alt-sjlj

(-experimental-wasm-enable-alt-sjlj is the new option this change introduces.)

While I want to use this for WASI, it should work for emscripten as well.

An example runtime and a few tests:
https://github.com/yamt/garbage/tree/wasm-sjlj-alt2/wasm/longjmp

Discussion:
https://docs.google.com/document/d/1ZvTPT36K5jjiedF8MCXbEmYjULJjI723aOAks1IdLLg/edit

>From 1283ae6b5536810f8fbe183eda80997aa9f5cdc3 Mon Sep 17 00:00:00 2001
From: YAMAMOTO Takashi <yamamoto at midokura.com>
Date: Fri, 9 Feb 2024 15:49:55 +0900
Subject: [PATCH] [WebAssembly] Implement an alternative translation for
 -wasm-enable-sjlj

Instead of maintaining per-function-invocation malloc()'ed tables to
track which functions each label belongs to, store the equivalent info
in jump buffers (jmp_buf) themselves.

Also, use less emscripten-looking ABI symbols:

    saveSetjmp     -> __wasm_sjlj_setjmp
    testSetjmp     -> __wasm_sjlj_test
    getTempRet0    -> (removed)
    __wasm_longjmp -> __wasm_sjlj_longjmp

Enabled with:

    -mllvm -wasm-enable-sjlj -mllvm -experimental-wasm-enable-alt-sjlj

(-experimental-wasm-enable-alt-sjlj is the new option this change
introduces.)

While I want to use this for WASI, it should work for emscripten as well.

An example runtime and a few tests:
https://github.com/yamt/garbage/tree/wasm-sjlj-alt2/wasm/longjmp

Discussion:
https://docs.google.com/document/d/1ZvTPT36K5jjiedF8MCXbEmYjULJjI723aOAks1IdLLg/edit
---
 clang/lib/Driver/ToolChains/WebAssembly.cpp   |  14 ++
 .../MCTargetDesc/WebAssemblyMCTargetDesc.cpp  |   3 +
 .../MCTargetDesc/WebAssemblyMCTargetDesc.h    |   1 +
 .../WebAssemblyLowerEmscriptenEHSjLj.cpp      | 174 +++++++++++-------
 .../WebAssembly/WebAssemblyTargetMachine.cpp  |   4 +
 5 files changed, 131 insertions(+), 65 deletions(-)

diff --git a/clang/lib/Driver/ToolChains/WebAssembly.cpp b/clang/lib/Driver/ToolChains/WebAssembly.cpp
index b8c2573d6265fb..2e7c8e6e8d13f7 100644
--- a/clang/lib/Driver/ToolChains/WebAssembly.cpp
+++ b/clang/lib/Driver/ToolChains/WebAssembly.cpp
@@ -386,6 +386,20 @@ void WebAssembly::addClangTargetOptions(const ArgList &DriverArgs,
       // Backend needs '-exception-model=wasm' to use Wasm EH instructions
       CC1Args.push_back("-exception-model=wasm");
     }
+
+    if (Opt.starts_with("-experimental-wasm-enable-alt-sjlj")) {
+      // '-mllvm -experimental-wasm-enable-alt-sjlj'  should be used with
+      // '-mllvm -wasm-enable-sjlj'
+      bool HasWasmEnableSjlj = false;
+      for (const Arg *A : DriverArgs.filtered(options::OPT_mllvm)) {
+        if (StringRef(A->getValue(0)) == "-wasm-enable-sjlj")
+          HasWasmEnableSjlj = true;
+      }
+      if (!HasWasmEnableSjlj)
+        getDriver().Diag(diag::err_drv_argument_only_allowed_with)
+            << "-mllvm -experimental-wasm-enable-alt-sjlj"
+            << "-mllvm -wasm-enable-sjlj";
+    }
   }
 }
 
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp
index e8f58a19d25e3b..7f15742367be09 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp
@@ -54,6 +54,9 @@ cl::opt<bool>
 // setjmp/longjmp handling using wasm EH instrutions
 cl::opt<bool> WebAssembly::WasmEnableSjLj(
     "wasm-enable-sjlj", cl::desc("WebAssembly setjmp/longjmp handling"));
+cl::opt<bool> WebAssembly::WasmEnableAltSjLj(
+    "experimental-wasm-enable-alt-sjlj",
+    cl::desc("Use experimental alternate ABI for --wasm-enable-sjlj"));
 
 static MCAsmInfo *createMCAsmInfo(const MCRegisterInfo & /*MRI*/,
                                   const Triple &TT,
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
index 15aeaaeb8c4a4e..d23de9d407d894 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
@@ -44,6 +44,7 @@ extern cl::opt<bool> WasmEnableEmEH;   // asm.js-style EH
 extern cl::opt<bool> WasmEnableEmSjLj; // asm.js-style SjLJ
 extern cl::opt<bool> WasmEnableEH;     // EH using Wasm EH instructions
 extern cl::opt<bool> WasmEnableSjLj;   // SjLj using Wasm EH instructions
+extern cl::opt<bool> WasmEnableAltSjLj; // Alt ABI for WasmEnableSjLj
 
 enum OperandType {
   /// Basic block label in a branch construct.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
index 77e6640d5a8224..fc76757011f5d8 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
@@ -300,6 +300,7 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
   bool EnableEmEH;     // Enable Emscripten exception handling
   bool EnableEmSjLj;   // Enable Emscripten setjmp/longjmp handling
   bool EnableWasmSjLj; // Enable Wasm setjmp/longjmp handling
+  bool EnableWasmAltSjLj; // Alt ABI for EnableWasmSjLj
   bool DoSjLj;         // Whether we actually perform setjmp/longjmp handling
 
   GlobalVariable *ThrewGV = nullptr;      // __THREW__ (Emscripten)
@@ -368,7 +369,8 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
   WebAssemblyLowerEmscriptenEHSjLj()
       : ModulePass(ID), EnableEmEH(WebAssembly::WasmEnableEmEH),
         EnableEmSjLj(WebAssembly::WasmEnableEmSjLj),
-        EnableWasmSjLj(WebAssembly::WasmEnableSjLj) {
+        EnableWasmSjLj(WebAssembly::WasmEnableSjLj),
+        EnableWasmAltSjLj(WebAssembly::WasmEnableAltSjLj) {
     assert(!(EnableEmSjLj && EnableWasmSjLj) &&
            "Two SjLj modes cannot be turned on at the same time");
     assert(!(EnableEmEH && EnableWasmSjLj) &&
@@ -619,6 +621,7 @@ static bool canLongjmp(const Value *Callee) {
   // There are functions in Emscripten's JS glue code or compiler-rt
   if (CalleeName == "__resumeException" || CalleeName == "llvm_eh_typeid_for" ||
       CalleeName == "saveSetjmp" || CalleeName == "testSetjmp" ||
+      CalleeName == "__wasm_sjlj_setjmp" || CalleeName == "__wasm_sjlj_test" ||
       CalleeName == "getTempRet0" || CalleeName == "setTempRet0")
     return false;
 
@@ -999,7 +1002,11 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
       // Register __wasm_longjmp function, which calls __builtin_wasm_longjmp.
       FunctionType *FTy = FunctionType::get(
           IRB.getVoidTy(), {Int8PtrTy, IRB.getInt32Ty()}, false);
-      WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_longjmp", &M);
+      if (EnableWasmAltSjLj) {
+        WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_sjlj_longjmp", &M);
+      } else {
+        WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_longjmp", &M);
+      }
       WasmLongjmpF->addFnAttr(Attribute::NoReturn);
     }
 
@@ -1007,17 +1014,30 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
       Type *Int8PtrTy = IRB.getPtrTy();
       Type *Int32PtrTy = IRB.getPtrTy();
       Type *Int32Ty = IRB.getInt32Ty();
-      // Register saveSetjmp function
-      FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
-      FunctionType *FTy = FunctionType::get(
-          Int32PtrTy,
-          {SetjmpFTy->getParamType(0), Int32Ty, Int32PtrTy, Int32Ty}, false);
-      SaveSetjmpF = getEmscriptenFunction(FTy, "saveSetjmp", &M);
 
       // Register testSetjmp function
-      FTy = FunctionType::get(Int32Ty,
-                              {getAddrIntType(&M), Int32PtrTy, Int32Ty}, false);
-      TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M);
+      if (EnableWasmAltSjLj) {
+        // Register __wasm_sjlj_setjmp function
+        FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
+        FunctionType *FTy = FunctionType::get(
+            IRB.getVoidTy(), {SetjmpFTy->getParamType(0), Int32Ty, Int32PtrTy},
+            false);
+        SaveSetjmpF = getEmscriptenFunction(FTy, "__wasm_sjlj_setjmp", &M);
+
+        FTy = FunctionType::get(Int32Ty, {Int32PtrTy, Int32PtrTy}, false);
+        TestSetjmpF = getEmscriptenFunction(FTy, "__wasm_sjlj_test", &M);
+      } else {
+        // Register saveSetjmp function
+        FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
+        FunctionType *FTy = FunctionType::get(
+            Int32PtrTy,
+            {SetjmpFTy->getParamType(0), Int32Ty, Int32PtrTy, Int32Ty}, false);
+        SaveSetjmpF = getEmscriptenFunction(FTy, "saveSetjmp", &M);
+
+        FTy = FunctionType::get(
+            Int32Ty, {getAddrIntType(&M), Int32PtrTy, Int32Ty}, false);
+        TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M);
+      }
 
       // wasm.catch() will be lowered down to wasm 'catch' instruction in
       // instruction selection.
@@ -1291,19 +1311,29 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
   Type *IntPtrTy = getAddrIntType(&M);
   Constant *size = ConstantInt::get(IntPtrTy, 40);
   IRB.SetInsertPoint(SetjmpTableSize);
-  auto *SetjmpTable = IRB.CreateMalloc(IntPtrTy, IRB.getInt32Ty(), size,
-                                       nullptr, nullptr, "setjmpTable");
-  SetjmpTable->setDebugLoc(FirstDL);
-  // CallInst::CreateMalloc may return a bitcast instruction if the result types
-  // mismatch. We need to set the debug loc for the original call too.
-  auto *MallocCall = SetjmpTable->stripPointerCasts();
-  if (auto *MallocCallI = dyn_cast<Instruction>(MallocCall)) {
-    MallocCallI->setDebugLoc(FirstDL);
+  Instruction *SetjmpTable;
+  if (EnableWasmAltSjLj) {
+    // This alloca'ed pointer is used by the runtime to identify function
+    // invocations. It's just for pointer comparisons. It will never
+    // be dereferenced.
+    SetjmpTable = IRB.CreateAlloca(IRB.getInt32Ty());
+    SetjmpTable->setDebugLoc(FirstDL);
+    SetjmpTableInsts.push_back(SetjmpTable);
+  } else {
+    SetjmpTable = IRB.CreateMalloc(IntPtrTy, IRB.getInt32Ty(), size, nullptr,
+                                   nullptr, "setjmpTable");
+    SetjmpTable->setDebugLoc(FirstDL);
+    // CallInst::CreateMalloc may return a bitcast instruction if the result
+    // types mismatch. We need to set the debug loc for the original call too.
+    auto *MallocCall = SetjmpTable->stripPointerCasts();
+    if (auto *MallocCallI = dyn_cast<Instruction>(MallocCall)) {
+      MallocCallI->setDebugLoc(FirstDL);
+    }
+    // setjmpTable[0] = 0;
+    IRB.CreateStore(IRB.getInt32(0), SetjmpTable);
+    SetjmpTableInsts.push_back(SetjmpTable);
+    SetjmpTableSizeInsts.push_back(SetjmpTableSize);
   }
-  // setjmpTable[0] = 0;
-  IRB.CreateStore(IRB.getInt32(0), SetjmpTable);
-  SetjmpTableInsts.push_back(SetjmpTable);
-  SetjmpTableSizeInsts.push_back(SetjmpTableSize);
 
   // Setjmp transformation
   SmallVector<PHINode *, 4> SetjmpRetPHIs;
@@ -1349,14 +1379,20 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
     // Our index in the function is our place in the array + 1 to avoid index
     // 0, because index 0 means the longjmp is not ours to handle.
     IRB.SetInsertPoint(CI);
-    Value *Args[] = {CI->getArgOperand(0), IRB.getInt32(SetjmpRetPHIs.size()),
-                     SetjmpTable, SetjmpTableSize};
-    Instruction *NewSetjmpTable =
-        IRB.CreateCall(SaveSetjmpF, Args, "setjmpTable");
-    Instruction *NewSetjmpTableSize =
-        IRB.CreateCall(GetTempRet0F, std::nullopt, "setjmpTableSize");
-    SetjmpTableInsts.push_back(NewSetjmpTable);
-    SetjmpTableSizeInsts.push_back(NewSetjmpTableSize);
+    if (EnableWasmAltSjLj) {
+      Value *Args[] = {CI->getArgOperand(0), IRB.getInt32(SetjmpRetPHIs.size()),
+                       SetjmpTable};
+      IRB.CreateCall(SaveSetjmpF, Args);
+    } else {
+      Value *Args[] = {CI->getArgOperand(0), IRB.getInt32(SetjmpRetPHIs.size()),
+                       SetjmpTable, SetjmpTableSize};
+      Instruction *NewSetjmpTable =
+          IRB.CreateCall(SaveSetjmpF, Args, "setjmpTable");
+      Instruction *NewSetjmpTableSize =
+          IRB.CreateCall(GetTempRet0F, std::nullopt, "setjmpTableSize");
+      SetjmpTableInsts.push_back(NewSetjmpTable);
+      SetjmpTableSizeInsts.push_back(NewSetjmpTableSize);
+    }
     ToErase.push_back(CI);
   }
 
@@ -1372,38 +1408,40 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
   for (Instruction *I : ToErase)
     I->eraseFromParent();
 
-  // Free setjmpTable buffer before each return instruction + function-exiting
-  // call
-  SmallVector<Instruction *, 16> ExitingInsts;
-  for (BasicBlock &BB : F) {
-    Instruction *TI = BB.getTerminator();
-    if (isa<ReturnInst>(TI))
-      ExitingInsts.push_back(TI);
-    // Any 'call' instruction with 'noreturn' attribute exits the function at
-    // this point. If this throws but unwinds to another EH pad within this
-    // function instead of exiting, this would have been an 'invoke', which
-    // happens if we use Wasm EH or Wasm SjLJ.
-    for (auto &I : BB) {
-      if (auto *CI = dyn_cast<CallInst>(&I)) {
-        bool IsNoReturn = CI->hasFnAttr(Attribute::NoReturn);
-        if (Function *CalleeF = CI->getCalledFunction())
-          IsNoReturn |= CalleeF->hasFnAttribute(Attribute::NoReturn);
-        if (IsNoReturn)
-          ExitingInsts.push_back(&I);
+  if (!EnableWasmAltSjLj) {
+    // Free setjmpTable buffer before each return instruction + function-exiting
+    // call
+    SmallVector<Instruction *, 16> ExitingInsts;
+    for (BasicBlock &BB : F) {
+      Instruction *TI = BB.getTerminator();
+      if (isa<ReturnInst>(TI))
+        ExitingInsts.push_back(TI);
+      // Any 'call' instruction with 'noreturn' attribute exits the function at
+      // this point. If this throws but unwinds to another EH pad within this
+      // function instead of exiting, this would have been an 'invoke', which
+      // happens if we use Wasm EH or Wasm SjLJ.
+      for (auto &I : BB) {
+        if (auto *CI = dyn_cast<CallInst>(&I)) {
+          bool IsNoReturn = CI->hasFnAttr(Attribute::NoReturn);
+          if (Function *CalleeF = CI->getCalledFunction())
+            IsNoReturn |= CalleeF->hasFnAttribute(Attribute::NoReturn);
+          if (IsNoReturn)
+            ExitingInsts.push_back(&I);
+        }
       }
     }
-  }
-  for (auto *I : ExitingInsts) {
-    DebugLoc DL = getOrCreateDebugLoc(I, F.getSubprogram());
-    // If this existing instruction is a call within a catchpad, we should add
-    // it as "funclet" to the operand bundle of 'free' call
-    SmallVector<OperandBundleDef, 1> Bundles;
-    if (auto *CB = dyn_cast<CallBase>(I))
-      if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
-        Bundles.push_back(OperandBundleDef(*Bundle));
-    IRB.SetInsertPoint(I);
-    auto *Free = IRB.CreateFree(SetjmpTable, Bundles);
-    Free->setDebugLoc(DL);
+    for (auto *I : ExitingInsts) {
+      DebugLoc DL = getOrCreateDebugLoc(I, F.getSubprogram());
+      // If this existing instruction is a call within a catchpad, we should add
+      // it as "funclet" to the operand bundle of 'free' call
+      SmallVector<OperandBundleDef, 1> Bundles;
+      if (auto *CB = dyn_cast<CallBase>(I))
+        if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
+          Bundles.push_back(OperandBundleDef(*Bundle));
+      IRB.SetInsertPoint(I);
+      auto *Free = IRB.CreateFree(SetjmpTable, Bundles);
+      Free->setDebugLoc(DL);
+    }
   }
 
   // Every call to saveSetjmp can change setjmpTable and setjmpTableSize
@@ -1738,10 +1776,16 @@ void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForWasmSjLj(
   BasicBlock *ThenBB = BasicBlock::Create(C, "if.then", &F);
   BasicBlock *EndBB = BasicBlock::Create(C, "if.end", &F);
   Value *EnvP = IRB.CreateBitCast(Env, getAddrPtrType(&M), "env.p");
-  Value *SetjmpID = IRB.CreateLoad(getAddrIntType(&M), EnvP, "setjmp.id");
-  Value *Label =
-      IRB.CreateCall(TestSetjmpF, {SetjmpID, SetjmpTable, SetjmpTableSize},
-                     OperandBundleDef("funclet", CatchPad), "label");
+  Value *Label;
+  if (EnableWasmAltSjLj) {
+    Label = IRB.CreateCall(TestSetjmpF, {EnvP, SetjmpTable},
+                           OperandBundleDef("funclet", CatchPad), "label");
+  } else {
+    Value *SetjmpID = IRB.CreateLoad(getAddrIntType(&M), EnvP, "setjmp.id");
+    Label =
+        IRB.CreateCall(TestSetjmpF, {SetjmpID, SetjmpTable, SetjmpTableSize},
+                       OperandBundleDef("funclet", CatchPad), "label");
+  }
   Value *Cmp = IRB.CreateICmpEQ(Label, IRB.getInt32(0));
   IRB.CreateCondBr(Cmp, ThenBB, EndBB);
 
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
index 70685b2e3bb2de..6db019034028bc 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
@@ -370,6 +370,7 @@ FunctionPass *WebAssemblyPassConfig::createTargetRegisterAllocator(bool) {
   return nullptr; // No reg alloc
 }
 
+using WebAssembly::WasmEnableAltSjLj;
 using WebAssembly::WasmEnableEH;
 using WebAssembly::WasmEnableEmEH;
 using WebAssembly::WasmEnableEmSjLj;
@@ -405,6 +406,9 @@ static void basicCheckForEHAndSjLj(TargetMachine *TM) {
     report_fatal_error(
         "-exception-model=wasm only allowed with at least one of "
         "-wasm-enable-eh or -wasm-enable-sjlj");
+  if (!WasmEnableSjLj && WasmEnableAltSjLj)
+    report_fatal_error("-experimental-wasm-enable-alt-sjlj only allowed with "
+                       "-wasm-enable-sjlj");
 
   // You can't enable two modes of EH at the same time
   if (WasmEnableEmEH && WasmEnableEH)



More information about the cfe-commits mailing list