[llvm] [WebAssembly] Support type checker for new EH (PR #111069)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 3 15:29:09 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-webassembly
Author: Heejin Ahn (aheejin)
<details>
<summary>Changes</summary>
This adds support for the new EH instructions (`try_table` and `throw_ref`) to the type checker.
One thing I'd like to improve is the source locations reported in errors for `catch_***` clauses. Currently they just point to the starting column of the `try_table` instruction itself. To point at the individual catch clauses instead, you would need to traverse the `OperandVector` and call `WebAssemblyOperand::isCatchList` on each operand to find the catch-list operand, but the `WebAssemblyOperand` class is defined in AsmParser and AsmTypeCheck does not have access to it:
https://github.com/llvm/llvm-project/blob/cdfdc857cbab0418b7e5116fd4255eb5566588bd/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp#L43-L204
And even if AsmTypeCheck had access to it, the list of catch clauses is currently represented as a single `WebAssemblyOperand`, so the current structure provides no way to get the starting location of each individual `catch_***` clause.
This also renames `valTypeToStackType` to `valTypesToStackTypes`, given that it converts a whole list of types (`ArrayRef<wasm::ValType>`) into a list of stack types rather than a single type.
---
Full diff: https://github.com/llvm/llvm-project/pull/111069.diff
4 Files Affected:
- (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp (+82-9)
- (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h (+5-2)
- (modified) llvm/test/MC/WebAssembly/eh-assembly.s (+2-2)
- (modified) llvm/test/MC/WebAssembly/type-checker-errors.s (+23)
``````````diff
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index effc2e65223cad..f01e19962ab9fc 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -59,7 +59,7 @@ void WebAssemblyAsmTypeCheck::localDecl(
}
void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
- LLVM_DEBUG({ dbgs() << Msg << getTypesString(Stack, 0) << "\n"; });
+ LLVM_DEBUG({ dbgs() << Msg << getTypesString(Stack) << "\n"; });
}
bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
@@ -116,8 +116,15 @@ std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
return SS.str();
}
+// Overload of getTypesString() for wasm::ValType lists: converts Types to
+// StackTypes via valTypesToStackTypes() and forwards them, with StartPos
+// unchanged, to the StackType overload.
+std::string
+WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<wasm::ValType> Types,
+ size_t StartPos) {
+ return getTypesString(valTypesToStackTypes(Types), StartPos);
+}
+
SmallVector<WebAssemblyAsmTypeCheck::StackType, 4>
-WebAssemblyAsmTypeCheck::valTypeToStackType(ArrayRef<wasm::ValType> ValTypes) {
+WebAssemblyAsmTypeCheck::valTypesToStackTypes(
+ ArrayRef<wasm::ValType> ValTypes) {
SmallVector<StackType, 4> Types(ValTypes.size());
std::transform(ValTypes.begin(), ValTypes.end(), Types.begin(),
[](wasm::ValType Val) -> StackType { return Val; });
@@ -127,7 +134,7 @@ WebAssemblyAsmTypeCheck::valTypeToStackType(ArrayRef<wasm::ValType> ValTypes) {
bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
ArrayRef<wasm::ValType> ValTypes,
bool ExactMatch) {
- return checkTypes(ErrorLoc, valTypeToStackType(ValTypes), ExactMatch);
+ return checkTypes(ErrorLoc, valTypesToStackTypes(ValTypes), ExactMatch);
}
bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
@@ -178,14 +185,14 @@ bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
: std::max((int)BlockStackStartPos,
(int)Stack.size() - (int)Types.size());
return typeError(ErrorLoc, "type mismatch, expected " +
- getTypesString(Types, 0) + " but got " +
+ getTypesString(Types) + " but got " +
getTypesString(Stack, StackStartPos));
}
bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
ArrayRef<wasm::ValType> ValTypes,
bool ExactMatch) {
- return popTypes(ErrorLoc, valTypeToStackType(ValTypes), ExactMatch);
+ return popTypes(ErrorLoc, valTypesToStackTypes(ValTypes), ExactMatch);
}
bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
@@ -215,7 +222,7 @@ bool WebAssemblyAsmTypeCheck::popAnyType(SMLoc ErrorLoc) {
}
void WebAssemblyAsmTypeCheck::pushTypes(ArrayRef<wasm::ValType> ValTypes) {
- Stack.append(valTypeToStackType(ValTypes));
+ Stack.append(valTypesToStackTypes(ValTypes));
}
bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
@@ -322,6 +329,63 @@ bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc, bool ExactMatch) {
return checkTypes(ErrorLoc, FuncInfo.Sig.Returns, ExactMatch);
}
+// Unlike the checkTypes() family, which checks types against the value
+// stack, this just compares two ValType lists for exact equality (same
+// length and the same type at every position).
+// NOTE: returns true when the two lists DIFFER, and false when they are
+// equal.
+static bool compareTypes(ArrayRef<wasm::ValType> TypesA,
+                         ArrayRef<wasm::ValType> TypesB) {
+  // ArrayRef::equals() performs the length check and element-wise
+  // comparison that was previously hand-rolled here.
+  return !TypesA.equals(TypesB);
+}
+
+// Type-checks the catch clauses of a try_table instruction.
+// Operand layout, as decoded below: operand 0 is the block type, operand 1
+// is the number of catch clauses, then each clause contributes its catch
+// opcode, a tag symbol (only for catch / catch_ref), and the branch depth of
+// its destination block.
+// For every clause, the values it sends to its destination (the tag's
+// params, plus an exnref for the *_ref variants) must exactly equal the
+// destination block's return types.
+// Returns true if any error was found in any clause.
+bool WebAssemblyAsmTypeCheck::checkTryTable(SMLoc ErrorLoc,
+                                            const MCInst &Inst) {
+  bool Error = false;
+  unsigned OpIdx = 1; // OpIdx 0 is the block type
+  int64_t NumCatches = Inst.getOperand(OpIdx++).getImm();
+  for (int64_t I = 0; I < NumCatches; I++) {
+    int64_t Opcode = Inst.getOperand(OpIdx++).getImm();
+    std::string ErrorMsgBase =
+        "try_table: catch index " + std::to_string(I) + ": ";
+
+    // Collect the types this clause sends to its destination block.
+    const wasm::WasmSignature *Sig = nullptr;
+    SmallVector<wasm::ValType> SentTypes;
+    if (Opcode == wasm::WASM_OPCODE_CATCH ||
+        Opcode == wasm::WASM_OPCODE_CATCH_REF) {
+      if (getSignature(ErrorLoc, Inst.getOperand(OpIdx++),
+                       wasm::WASM_SYMBOL_TYPE_TAG, Sig)) {
+        // The tag's signature is unknown, so the sent types cannot be
+        // computed. Skip this clause's depth operand to keep OpIdx in sync,
+        // and don't compare types, which would only add a misleading
+        // follow-up mismatch error. (Presumably getSignature() has already
+        // emitted a diagnostic; confirm.)
+        Error = true;
+        ++OpIdx;
+        continue;
+      }
+      SentTypes.append(Sig->Params.begin(), Sig->Params.end());
+    }
+    if (Opcode == wasm::WASM_OPCODE_CATCH_REF ||
+        Opcode == wasm::WASM_OPCODE_CATCH_ALL_REF)
+      SentTypes.push_back(wasm::ValType::EXNREF);
+
+    // Keep the depth as int64_t instead of truncating it to unsigned, so an
+    // out-of-range immediate is reported as an invalid depth rather than
+    // wrapping around to a seemingly valid one.
+    int64_t Level = Inst.getOperand(OpIdx++).getImm();
+    if (Level >= 0 && static_cast<uint64_t>(Level) < BlockInfoStack.size()) {
+      const auto &DestBlockInfo =
+          BlockInfoStack[BlockInfoStack.size() - Level - 1];
+      // compareTypes() returns true on mismatch.
+      if (compareTypes(SentTypes, DestBlockInfo.Sig.Returns)) {
+        std::string ErrorMsg =
+            ErrorMsgBase + "type mismatch, catch tag type is " +
+            getTypesString(SentTypes) + ", but destination's return type is " +
+            getTypesString(DestBlockInfo.Sig.Returns);
+        Error |= typeError(ErrorLoc, ErrorMsg);
+      }
+    } else {
+      // Accumulate with |= (as above) so an error recorded for an earlier
+      // clause is never cleared by this assignment.
+      Error |= typeError(ErrorLoc, ErrorMsgBase + "invalid depth " +
+                                       std::to_string(Level));
+    }
+  }
+  return Error;
+}
+
bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
OperandVector &Operands) {
auto Opc = Inst.getOpcode();
@@ -460,10 +524,13 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return popType(ErrorLoc, Any{});
}
- if (Name == "block" || Name == "loop" || Name == "if" || Name == "try") {
+ if (Name == "block" || Name == "loop" || Name == "if" || Name == "try" ||
+ Name == "try_table") {
bool Error = Name == "if" && popType(ErrorLoc, wasm::ValType::I32);
// Pop block input parameters and check their types are correct
Error |= popTypes(ErrorLoc, LastSig.Params);
+ if (Name == "try_table")
+ Error |= checkTryTable(ErrorLoc, Inst);
// Push a new block info
BlockInfoStack.push_back({LastSig, Stack.size(), Name == "loop"});
// Push back block input parameters
@@ -472,8 +539,8 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
}
if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
- Name == "end_try" || Name == "delegate" || Name == "else" ||
- Name == "catch" || Name == "catch_all") {
+ Name == "end_try" || Name == "delegate" || Name == "end_try_table" ||
+ Name == "else" || Name == "catch" || Name == "catch_all") {
assert(!BlockInfoStack.empty());
// Check if the types on the stack match with the block return type
const auto &LastBlockInfo = BlockInfoStack.back();
@@ -586,6 +653,12 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return Error;
}
+ if (Name == "throw_ref") {
+ bool Error = popType(ErrorLoc, wasm::ValType::EXNREF);
+ pushType(Polymorphic{});
+ return Error;
+ }
+
// The current instruction is a stack instruction which doesn't have
// explicit operands that indicate push/pop types, so we get those from
// the register version of the same instruction.
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 596fb27bce94e6..e6fddf98060265 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -65,9 +65,11 @@ class WebAssemblyAsmTypeCheck final {
void pushTypes(ArrayRef<wasm::ValType> Types);
void pushType(StackType Type) { Stack.push_back(Type); }
bool match(StackType TypeA, StackType TypeB);
- std::string getTypesString(ArrayRef<StackType> Types, size_t StartPos);
+ std::string getTypesString(ArrayRef<wasm::ValType> Types,
+ size_t StartPos = 0);
+ std::string getTypesString(ArrayRef<StackType> Types, size_t StartPos = 0);
SmallVector<StackType, 4>
- valTypeToStackType(ArrayRef<wasm::ValType> ValTypes);
+ valTypesToStackTypes(ArrayRef<wasm::ValType> ValTypes);
void dumpTypeStack(Twine Msg);
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
@@ -80,6 +82,7 @@ class WebAssemblyAsmTypeCheck final {
bool getTable(SMLoc ErrorLoc, const MCOperand &TableOp, wasm::ValType &Type);
bool getSignature(SMLoc ErrorLoc, const MCOperand &SigOp,
wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
+ bool checkTryTable(SMLoc ErrorLoc, const MCInst &Inst);
public:
WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,
diff --git a/llvm/test/MC/WebAssembly/eh-assembly.s b/llvm/test/MC/WebAssembly/eh-assembly.s
index b4d6b324d96e3e..a03c1b8e1aed14 100644
--- a/llvm/test/MC/WebAssembly/eh-assembly.s
+++ b/llvm/test/MC/WebAssembly/eh-assembly.s
@@ -1,6 +1,6 @@
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+exception-handling --no-type-check < %s | FileCheck %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+exception-handling < %s | FileCheck %s
# Check that it converts to .o without errors, but don't check any output:
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+exception-handling --no-type-check -o %t.o < %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+exception-handling -o %t.o < %s
.tagtype __cpp_exception i32
.tagtype __c_longjmp i32
diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s
index df537a9ba5d0a0..74ab17fdefdad9 100644
--- a/llvm/test/MC/WebAssembly/type-checker-errors.s
+++ b/llvm/test/MC/WebAssembly/type-checker-errors.s
@@ -944,3 +944,26 @@ block_param_and_return:
# CHECK: :[[@LINE+1]]:3: error: type mismatch, expected [] but got [f32]
end_function
+
+ .tagtype __cpp_exception i32
+
+eh_test:
+ .functype eh_test () -> ()
+ block i32
+ block i32
+ block i32
+ block
+# CHECK: :[[@LINE+4]]:11: error: try_table: catch index 0: type mismatch, catch tag type is [i32], but destination's return type is []
+# CHECK: :[[@LINE+3]]:11: error: try_table: catch index 1: type mismatch, catch tag type is [i32, exnref], but destination's return type is [i32]
+# CHECK: :[[@LINE+2]]:11: error: try_table: catch index 2: type mismatch, catch tag type is [], but destination's return type is [i32]
+# CHECK: :[[@LINE+1]]:11: error: try_table: catch index 3: type mismatch, catch tag type is [exnref], but destination's return type is [i32]
+ try_table i32 (catch __cpp_exception 0) (catch_ref __cpp_exception 1) (catch_all 2) (catch_all_ref 3)
+# CHECK: :[[@LINE+1]]:11: error: type mismatch, expected [i32] but got []
+ end_try_table
+# CHECK: :[[@LINE+1]]:9: error: type mismatch, expected [] but got [i32]
+ end_block
+ end_block
+ end_block
+ end_block
+ drop
+ end_function
``````````
</details>
https://github.com/llvm/llvm-project/pull/111069
More information about the llvm-commits
mailing list