[clang] [llvm] [WebAssembly] Implement an alternative translation for -wasm-enable-sjlj (PR #84137)
YAMAMOTO Takashi via cfe-commits
cfe-commits at lists.llvm.org
Wed Mar 6 00:14:31 PST 2024
https://github.com/yamt created https://github.com/llvm/llvm-project/pull/84137
Instead of maintaining per-function-invocation malloc()'ed tables to track which functions each label belongs to, store the equivalent info in jump buffers (jmp_buf) themselves.
Also, use less emscripten-looking ABI symbols:
saveSetjmp -> __wasm_sjlj_setjmp
testSetjmp -> __wasm_sjlj_test
getTempRet0 -> (removed)
__wasm_longjmp -> __wasm_sjlj_longjmp
Enabled with:
-mllvm -wasm-enable-sjlj -mllvm -experimental-wasm-enable-alt-sjlj
(-experimental-wasm-enable-alt-sjlj is the new option this change introduces.)
While I want to use this for WASI, it should work for emscripten as well.
An example runtime and a few tests:
https://github.com/yamt/garbage/tree/wasm-sjlj-alt2/wasm/longjmp
Discussion:
https://docs.google.com/document/d/1ZvTPT36K5jjiedF8MCXbEmYjULJjI723aOAks1IdLLg/edit
>From 1283ae6b5536810f8fbe183eda80997aa9f5cdc3 Mon Sep 17 00:00:00 2001
From: YAMAMOTO Takashi <yamamoto at midokura.com>
Date: Fri, 9 Feb 2024 15:49:55 +0900
Subject: [PATCH] [WebAssembly] Implement an alternative translation for
-wasm-enable-sjlj
Instead of maintaining per-function-invocation malloc()'ed tables to
track which functions each label belongs to, store the equivalent info
in jump buffers (jmp_buf) themselves.
Also, use less emscripten-looking ABI symbols:
saveSetjmp -> __wasm_sjlj_setjmp
testSetjmp -> __wasm_sjlj_test
getTempRet0 -> (removed)
__wasm_longjmp -> __wasm_sjlj_longjmp
Enabled with:
-mllvm -wasm-enable-sjlj -mllvm -experimental-wasm-enable-alt-sjlj
(-experimental-wasm-enable-alt-sjlj is the new option this change
introduces.)
While I want to use this for WASI, it should work for emscripten as well.
An example runtime and a few tests:
https://github.com/yamt/garbage/tree/wasm-sjlj-alt2/wasm/longjmp
Discussion:
https://docs.google.com/document/d/1ZvTPT36K5jjiedF8MCXbEmYjULJjI723aOAks1IdLLg/edit
---
clang/lib/Driver/ToolChains/WebAssembly.cpp | 14 ++
.../MCTargetDesc/WebAssemblyMCTargetDesc.cpp | 3 +
.../MCTargetDesc/WebAssemblyMCTargetDesc.h | 1 +
.../WebAssemblyLowerEmscriptenEHSjLj.cpp | 174 +++++++++++-------
.../WebAssembly/WebAssemblyTargetMachine.cpp | 4 +
5 files changed, 131 insertions(+), 65 deletions(-)
diff --git a/clang/lib/Driver/ToolChains/WebAssembly.cpp b/clang/lib/Driver/ToolChains/WebAssembly.cpp
index b8c2573d6265fb..2e7c8e6e8d13f7 100644
--- a/clang/lib/Driver/ToolChains/WebAssembly.cpp
+++ b/clang/lib/Driver/ToolChains/WebAssembly.cpp
@@ -386,6 +386,20 @@ void WebAssembly::addClangTargetOptions(const ArgList &DriverArgs,
// Backend needs '-exception-model=wasm' to use Wasm EH instructions
CC1Args.push_back("-exception-model=wasm");
}
+
+ if (Opt.starts_with("-experimental-wasm-enable-alt-sjlj")) {
+ // '-mllvm -experimental-wasm-enable-alt-sjlj' should be used with
+ // '-mllvm -wasm-enable-sjlj'
+ bool HasWasmEnableSjlj = false;
+ for (const Arg *A : DriverArgs.filtered(options::OPT_mllvm)) {
+ if (StringRef(A->getValue(0)) == "-wasm-enable-sjlj")
+ HasWasmEnableSjlj = true;
+ }
+ if (!HasWasmEnableSjlj)
+ getDriver().Diag(diag::err_drv_argument_only_allowed_with)
+ << "-mllvm -experimental-wasm-enable-alt-sjlj"
+ << "-mllvm -wasm-enable-sjlj";
+ }
}
}
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp
index e8f58a19d25e3b..7f15742367be09 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.cpp
@@ -54,6 +54,9 @@ cl::opt<bool>
// setjmp/longjmp handling using wasm EH instrutions
cl::opt<bool> WebAssembly::WasmEnableSjLj(
"wasm-enable-sjlj", cl::desc("WebAssembly setjmp/longjmp handling"));
+cl::opt<bool> WebAssembly::WasmEnableAltSjLj(
+ "experimental-wasm-enable-alt-sjlj",
+ cl::desc("Use experimental alternate ABI for --wasm-enable-sjlj"));
static MCAsmInfo *createMCAsmInfo(const MCRegisterInfo & /*MRI*/,
const Triple &TT,
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
index 15aeaaeb8c4a4e..d23de9d407d894 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
@@ -44,6 +44,7 @@ extern cl::opt<bool> WasmEnableEmEH; // asm.js-style EH
extern cl::opt<bool> WasmEnableEmSjLj; // asm.js-style SjLJ
extern cl::opt<bool> WasmEnableEH; // EH using Wasm EH instructions
extern cl::opt<bool> WasmEnableSjLj; // SjLj using Wasm EH instructions
+extern cl::opt<bool> WasmEnableAltSjLj; // Alt ABI for WasmEnableSjLj
enum OperandType {
/// Basic block label in a branch construct.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
index 77e6640d5a8224..fc76757011f5d8 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
@@ -300,6 +300,7 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
bool EnableEmEH; // Enable Emscripten exception handling
bool EnableEmSjLj; // Enable Emscripten setjmp/longjmp handling
bool EnableWasmSjLj; // Enable Wasm setjmp/longjmp handling
+ bool EnableWasmAltSjLj; // Alt ABI for EnableWasmSjLj
bool DoSjLj; // Whether we actually perform setjmp/longjmp handling
GlobalVariable *ThrewGV = nullptr; // __THREW__ (Emscripten)
@@ -368,7 +369,8 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
WebAssemblyLowerEmscriptenEHSjLj()
: ModulePass(ID), EnableEmEH(WebAssembly::WasmEnableEmEH),
EnableEmSjLj(WebAssembly::WasmEnableEmSjLj),
- EnableWasmSjLj(WebAssembly::WasmEnableSjLj) {
+ EnableWasmSjLj(WebAssembly::WasmEnableSjLj),
+ EnableWasmAltSjLj(WebAssembly::WasmEnableAltSjLj) {
assert(!(EnableEmSjLj && EnableWasmSjLj) &&
"Two SjLj modes cannot be turned on at the same time");
assert(!(EnableEmEH && EnableWasmSjLj) &&
@@ -619,6 +621,7 @@ static bool canLongjmp(const Value *Callee) {
// There are functions in Emscripten's JS glue code or compiler-rt
if (CalleeName == "__resumeException" || CalleeName == "llvm_eh_typeid_for" ||
CalleeName == "saveSetjmp" || CalleeName == "testSetjmp" ||
+ CalleeName == "__wasm_sjlj_setjmp" || CalleeName == "__wasm_sjlj_test" ||
CalleeName == "getTempRet0" || CalleeName == "setTempRet0")
return false;
@@ -999,7 +1002,11 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
// Register __wasm_longjmp function, which calls __builtin_wasm_longjmp.
FunctionType *FTy = FunctionType::get(
IRB.getVoidTy(), {Int8PtrTy, IRB.getInt32Ty()}, false);
- WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_longjmp", &M);
+ if (EnableWasmAltSjLj) {
+ WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_sjlj_longjmp", &M);
+ } else {
+ WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_longjmp", &M);
+ }
WasmLongjmpF->addFnAttr(Attribute::NoReturn);
}
@@ -1007,17 +1014,30 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
Type *Int8PtrTy = IRB.getPtrTy();
Type *Int32PtrTy = IRB.getPtrTy();
Type *Int32Ty = IRB.getInt32Ty();
- // Register saveSetjmp function
- FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
- FunctionType *FTy = FunctionType::get(
- Int32PtrTy,
- {SetjmpFTy->getParamType(0), Int32Ty, Int32PtrTy, Int32Ty}, false);
- SaveSetjmpF = getEmscriptenFunction(FTy, "saveSetjmp", &M);
// Register testSetjmp function
- FTy = FunctionType::get(Int32Ty,
- {getAddrIntType(&M), Int32PtrTy, Int32Ty}, false);
- TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M);
+ if (EnableWasmAltSjLj) {
+ // Register saveSetjmp function
+ FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
+ FunctionType *FTy = FunctionType::get(
+ IRB.getVoidTy(), {SetjmpFTy->getParamType(0), Int32Ty, Int32PtrTy},
+ false);
+ SaveSetjmpF = getEmscriptenFunction(FTy, "__wasm_sjlj_setjmp", &M);
+
+ FTy = FunctionType::get(Int32Ty, {Int32PtrTy, Int32PtrTy}, false);
+ TestSetjmpF = getEmscriptenFunction(FTy, "__wasm_sjlj_test", &M);
+ } else {
+ // Register saveSetjmp function
+ FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
+ FunctionType *FTy = FunctionType::get(
+ Int32PtrTy,
+ {SetjmpFTy->getParamType(0), Int32Ty, Int32PtrTy, Int32Ty}, false);
+ SaveSetjmpF = getEmscriptenFunction(FTy, "saveSetjmp", &M);
+
+ FTy = FunctionType::get(
+ Int32Ty, {getAddrIntType(&M), Int32PtrTy, Int32Ty}, false);
+ TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M);
+ }
// wasm.catch() will be lowered down to wasm 'catch' instruction in
// instruction selection.
@@ -1291,19 +1311,29 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
Type *IntPtrTy = getAddrIntType(&M);
Constant *size = ConstantInt::get(IntPtrTy, 40);
IRB.SetInsertPoint(SetjmpTableSize);
- auto *SetjmpTable = IRB.CreateMalloc(IntPtrTy, IRB.getInt32Ty(), size,
- nullptr, nullptr, "setjmpTable");
- SetjmpTable->setDebugLoc(FirstDL);
- // CallInst::CreateMalloc may return a bitcast instruction if the result types
- // mismatch. We need to set the debug loc for the original call too.
- auto *MallocCall = SetjmpTable->stripPointerCasts();
- if (auto *MallocCallI = dyn_cast<Instruction>(MallocCall)) {
- MallocCallI->setDebugLoc(FirstDL);
+ Instruction *SetjmpTable;
+ if (EnableWasmAltSjLj) {
+ // This alloca'ed pointer is used by the runtime to identify function
+ // invocations. It's just for pointer comparisons. It will never
+ // be dereferenced.
+ SetjmpTable = IRB.CreateAlloca(IRB.getInt32Ty());
+ SetjmpTable->setDebugLoc(FirstDL);
+ SetjmpTableInsts.push_back(SetjmpTable);
+ } else {
+ SetjmpTable = IRB.CreateMalloc(IntPtrTy, IRB.getInt32Ty(), size, nullptr,
+ nullptr, "setjmpTable");
+ SetjmpTable->setDebugLoc(FirstDL);
+ // CallInst::CreateMalloc may return a bitcast instruction if the result
+ // types mismatch. We need to set the debug loc for the original call too.
+ auto *MallocCall = SetjmpTable->stripPointerCasts();
+ if (auto *MallocCallI = dyn_cast<Instruction>(MallocCall)) {
+ MallocCallI->setDebugLoc(FirstDL);
+ }
+ // setjmpTable[0] = 0;
+ IRB.CreateStore(IRB.getInt32(0), SetjmpTable);
+ SetjmpTableInsts.push_back(SetjmpTable);
+ SetjmpTableSizeInsts.push_back(SetjmpTableSize);
}
- // setjmpTable[0] = 0;
- IRB.CreateStore(IRB.getInt32(0), SetjmpTable);
- SetjmpTableInsts.push_back(SetjmpTable);
- SetjmpTableSizeInsts.push_back(SetjmpTableSize);
// Setjmp transformation
SmallVector<PHINode *, 4> SetjmpRetPHIs;
@@ -1349,14 +1379,20 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
// Our index in the function is our place in the array + 1 to avoid index
// 0, because index 0 means the longjmp is not ours to handle.
IRB.SetInsertPoint(CI);
- Value *Args[] = {CI->getArgOperand(0), IRB.getInt32(SetjmpRetPHIs.size()),
- SetjmpTable, SetjmpTableSize};
- Instruction *NewSetjmpTable =
- IRB.CreateCall(SaveSetjmpF, Args, "setjmpTable");
- Instruction *NewSetjmpTableSize =
- IRB.CreateCall(GetTempRet0F, std::nullopt, "setjmpTableSize");
- SetjmpTableInsts.push_back(NewSetjmpTable);
- SetjmpTableSizeInsts.push_back(NewSetjmpTableSize);
+ if (EnableWasmAltSjLj) {
+ Value *Args[] = {CI->getArgOperand(0), IRB.getInt32(SetjmpRetPHIs.size()),
+ SetjmpTable};
+ IRB.CreateCall(SaveSetjmpF, Args);
+ } else {
+ Value *Args[] = {CI->getArgOperand(0), IRB.getInt32(SetjmpRetPHIs.size()),
+ SetjmpTable, SetjmpTableSize};
+ Instruction *NewSetjmpTable =
+ IRB.CreateCall(SaveSetjmpF, Args, "setjmpTable");
+ Instruction *NewSetjmpTableSize =
+ IRB.CreateCall(GetTempRet0F, std::nullopt, "setjmpTableSize");
+ SetjmpTableInsts.push_back(NewSetjmpTable);
+ SetjmpTableSizeInsts.push_back(NewSetjmpTableSize);
+ }
ToErase.push_back(CI);
}
@@ -1372,38 +1408,40 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
for (Instruction *I : ToErase)
I->eraseFromParent();
- // Free setjmpTable buffer before each return instruction + function-exiting
- // call
- SmallVector<Instruction *, 16> ExitingInsts;
- for (BasicBlock &BB : F) {
- Instruction *TI = BB.getTerminator();
- if (isa<ReturnInst>(TI))
- ExitingInsts.push_back(TI);
- // Any 'call' instruction with 'noreturn' attribute exits the function at
- // this point. If this throws but unwinds to another EH pad within this
- // function instead of exiting, this would have been an 'invoke', which
- // happens if we use Wasm EH or Wasm SjLJ.
- for (auto &I : BB) {
- if (auto *CI = dyn_cast<CallInst>(&I)) {
- bool IsNoReturn = CI->hasFnAttr(Attribute::NoReturn);
- if (Function *CalleeF = CI->getCalledFunction())
- IsNoReturn |= CalleeF->hasFnAttribute(Attribute::NoReturn);
- if (IsNoReturn)
- ExitingInsts.push_back(&I);
+ if (!EnableWasmAltSjLj) {
+ // Free setjmpTable buffer before each return instruction + function-exiting
+ // call
+ SmallVector<Instruction *, 16> ExitingInsts;
+ for (BasicBlock &BB : F) {
+ Instruction *TI = BB.getTerminator();
+ if (isa<ReturnInst>(TI))
+ ExitingInsts.push_back(TI);
+ // Any 'call' instruction with 'noreturn' attribute exits the function at
+ // this point. If this throws but unwinds to another EH pad within this
+ // function instead of exiting, this would have been an 'invoke', which
+ // happens if we use Wasm EH or Wasm SjLJ.
+ for (auto &I : BB) {
+ if (auto *CI = dyn_cast<CallInst>(&I)) {
+ bool IsNoReturn = CI->hasFnAttr(Attribute::NoReturn);
+ if (Function *CalleeF = CI->getCalledFunction())
+ IsNoReturn |= CalleeF->hasFnAttribute(Attribute::NoReturn);
+ if (IsNoReturn)
+ ExitingInsts.push_back(&I);
+ }
}
}
- }
- for (auto *I : ExitingInsts) {
- DebugLoc DL = getOrCreateDebugLoc(I, F.getSubprogram());
- // If this existing instruction is a call within a catchpad, we should add
- // it as "funclet" to the operand bundle of 'free' call
- SmallVector<OperandBundleDef, 1> Bundles;
- if (auto *CB = dyn_cast<CallBase>(I))
- if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
- Bundles.push_back(OperandBundleDef(*Bundle));
- IRB.SetInsertPoint(I);
- auto *Free = IRB.CreateFree(SetjmpTable, Bundles);
- Free->setDebugLoc(DL);
+ for (auto *I : ExitingInsts) {
+ DebugLoc DL = getOrCreateDebugLoc(I, F.getSubprogram());
+ // If this existing instruction is a call within a catchpad, we should add
+ // it as "funclet" to the operand bundle of 'free' call
+ SmallVector<OperandBundleDef, 1> Bundles;
+ if (auto *CB = dyn_cast<CallBase>(I))
+ if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
+ Bundles.push_back(OperandBundleDef(*Bundle));
+ IRB.SetInsertPoint(I);
+ auto *Free = IRB.CreateFree(SetjmpTable, Bundles);
+ Free->setDebugLoc(DL);
+ }
}
// Every call to saveSetjmp can change setjmpTable and setjmpTableSize
@@ -1738,10 +1776,16 @@ void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForWasmSjLj(
BasicBlock *ThenBB = BasicBlock::Create(C, "if.then", &F);
BasicBlock *EndBB = BasicBlock::Create(C, "if.end", &F);
Value *EnvP = IRB.CreateBitCast(Env, getAddrPtrType(&M), "env.p");
- Value *SetjmpID = IRB.CreateLoad(getAddrIntType(&M), EnvP, "setjmp.id");
- Value *Label =
- IRB.CreateCall(TestSetjmpF, {SetjmpID, SetjmpTable, SetjmpTableSize},
- OperandBundleDef("funclet", CatchPad), "label");
+ Value *Label;
+ if (EnableWasmAltSjLj) {
+ Label = IRB.CreateCall(TestSetjmpF, {EnvP, SetjmpTable},
+ OperandBundleDef("funclet", CatchPad), "label");
+ } else {
+ Value *SetjmpID = IRB.CreateLoad(getAddrIntType(&M), EnvP, "setjmp.id");
+ Label =
+ IRB.CreateCall(TestSetjmpF, {SetjmpID, SetjmpTable, SetjmpTableSize},
+ OperandBundleDef("funclet", CatchPad), "label");
+ }
Value *Cmp = IRB.CreateICmpEQ(Label, IRB.getInt32(0));
IRB.CreateCondBr(Cmp, ThenBB, EndBB);
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
index 70685b2e3bb2de..6db019034028bc 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
@@ -370,6 +370,7 @@ FunctionPass *WebAssemblyPassConfig::createTargetRegisterAllocator(bool) {
return nullptr; // No reg alloc
}
+using WebAssembly::WasmEnableAltSjLj;
using WebAssembly::WasmEnableEH;
using WebAssembly::WasmEnableEmEH;
using WebAssembly::WasmEnableEmSjLj;
@@ -405,6 +406,9 @@ static void basicCheckForEHAndSjLj(TargetMachine *TM) {
report_fatal_error(
"-exception-model=wasm only allowed with at least one of "
"-wasm-enable-eh or -wasm-enable-sjlj");
+ if (!WasmEnableSjLj && WasmEnableAltSjLj)
+ report_fatal_error("-experimental-wasm-enable-alt-sjlj only allowed with "
+ "-wasm-enable-sjlj");
// You can't enable two modes of EH at the same time
if (WasmEnableEmEH && WasmEnableEH)
More information about the cfe-commits
mailing list