[llvm] [WebAssembly] Support type checker for new EH (PR #111069)

Heejin Ahn via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 3 15:32:06 PDT 2024


https://github.com/aheejin updated https://github.com/llvm/llvm-project/pull/111069

>From 98122cdb6fb76fee9c4d6ee15c29b6e6dee7d235 Mon Sep 17 00:00:00 2001
From: Heejin Ahn <aheejin at gmail.com>
Date: Wed, 2 Oct 2024 01:19:47 +0000
Subject: [PATCH 1/2] [WebAssembly] Support type checker for new EH

This adds support for the new EH instructions (`try_table` and
`throw_ref`) to the type checker.

One thing I'd like to improve is the error locations for `catch_***`
clauses. Currently they all point to the starting column of the
`try_table` instruction itself. To find where each catch clause starts,
we would need to traverse the `OperandVector` and call
`WebAssemblyOperand::isCatchList` on each operand to find the catch list
operand. However, the `WebAssemblyOperand` class is defined in AsmParser,
and AsmTypeCheck does not have access to it:
https://github.com/llvm/llvm-project/blob/cdfdc857cbab0418b7e5116fd4255eb5566588bd/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp#L43-L204
And even if AsmTypeCheck had access to it, the parser currently
represents the whole list of catch clauses as a single
`WebAssemblyOperand`. So in the current structure there is no way to get
the starting location of each individual `catch_***` clause.

This also renames `valTypeToStackType` to `valTypesToStackTypes`, given
that it converts a list of types rather than a single type.
---
 .../AsmParser/WebAssemblyAsmTypeCheck.cpp     | 91 +++++++++++++++++--
 .../AsmParser/WebAssemblyAsmTypeCheck.h       |  7 +-
 llvm/test/MC/WebAssembly/eh-assembly.s        |  4 +-
 .../test/MC/WebAssembly/type-checker-errors.s | 23 +++++
 4 files changed, 112 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index effc2e65223cad..f01e19962ab9fc 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -59,7 +59,7 @@ void WebAssemblyAsmTypeCheck::localDecl(
 }
 
 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
-  LLVM_DEBUG({ dbgs() << Msg << getTypesString(Stack, 0) << "\n"; });
+  LLVM_DEBUG({ dbgs() << Msg << getTypesString(Stack) << "\n"; });
 }
 
 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
@@ -116,8 +116,15 @@ std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
   return SS.str();
 }
 
+std::string
+WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<wasm::ValType> Types,
+                                        size_t StartPos) {
+  return getTypesString(valTypesToStackTypes(Types), StartPos);
+}
+
 SmallVector<WebAssemblyAsmTypeCheck::StackType, 4>
-WebAssemblyAsmTypeCheck::valTypeToStackType(ArrayRef<wasm::ValType> ValTypes) {
+WebAssemblyAsmTypeCheck::valTypesToStackTypes(
+    ArrayRef<wasm::ValType> ValTypes) {
   SmallVector<StackType, 4> Types(ValTypes.size());
   std::transform(ValTypes.begin(), ValTypes.end(), Types.begin(),
                  [](wasm::ValType Val) -> StackType { return Val; });
@@ -127,7 +134,7 @@ WebAssemblyAsmTypeCheck::valTypeToStackType(ArrayRef<wasm::ValType> ValTypes) {
 bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
                                          ArrayRef<wasm::ValType> ValTypes,
                                          bool ExactMatch) {
-  return checkTypes(ErrorLoc, valTypeToStackType(ValTypes), ExactMatch);
+  return checkTypes(ErrorLoc, valTypesToStackTypes(ValTypes), ExactMatch);
 }
 
 bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
@@ -178,14 +185,14 @@ bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
                            : std::max((int)BlockStackStartPos,
                                       (int)Stack.size() - (int)Types.size());
   return typeError(ErrorLoc, "type mismatch, expected " +
-                                 getTypesString(Types, 0) + " but got " +
+                                 getTypesString(Types) + " but got " +
                                  getTypesString(Stack, StackStartPos));
 }
 
 bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
                                        ArrayRef<wasm::ValType> ValTypes,
                                        bool ExactMatch) {
-  return popTypes(ErrorLoc, valTypeToStackType(ValTypes), ExactMatch);
+  return popTypes(ErrorLoc, valTypesToStackTypes(ValTypes), ExactMatch);
 }
 
 bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
@@ -215,7 +222,7 @@ bool WebAssemblyAsmTypeCheck::popAnyType(SMLoc ErrorLoc) {
 }
 
 void WebAssemblyAsmTypeCheck::pushTypes(ArrayRef<wasm::ValType> ValTypes) {
-  Stack.append(valTypeToStackType(ValTypes));
+  Stack.append(valTypesToStackTypes(ValTypes));
 }
 
 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
@@ -322,6 +329,63 @@ bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc, bool ExactMatch) {
   return checkTypes(ErrorLoc, FuncInfo.Sig.Returns, ExactMatch);
 }
 
+// Unlike the checkTypes() family, this does not look at the value stack; it
+// only compares two ValType lists. Returns true if they differ (an error).
+static bool compareTypes(ArrayRef<wasm::ValType> TypesA,
+                         ArrayRef<wasm::ValType> TypesB) {
+  return !TypesA.equals(TypesB); // equals() checks the sizes first
+}
+
+// Checks each catch clause of a try_table: the values a clause sends to its
+// label (the tag's params, plus an exnref for catch_ref/catch_all_ref) must
+// match the label's types exactly. Returns true if an error was reported.
+bool WebAssemblyAsmTypeCheck::checkTryTable(SMLoc ErrorLoc,
+                                            const MCInst &Inst) {
+  bool Error = false;
+  unsigned OpIdx = 1; // OpIdx 0 is the block type
+  int64_t NumCatches = Inst.getOperand(OpIdx++).getImm();
+  for (int64_t I = 0; I < NumCatches; I++) {
+    int64_t Opcode = Inst.getOperand(OpIdx++).getImm();
+    std::string ErrorMsgBase =
+        "try_table: catch index " + std::to_string(I) + ": ";
+
+    const wasm::WasmSignature *Sig = nullptr;
+    SmallVector<wasm::ValType> SentTypes;
+    bool TagResolved = true;
+    if (Opcode == wasm::WASM_OPCODE_CATCH ||
+        Opcode == wasm::WASM_OPCODE_CATCH_REF) {
+      // The tag operand is consumed even if its signature lookup fails.
+      TagResolved = !getSignature(ErrorLoc, Inst.getOperand(OpIdx++),
+                                  wasm::WASM_SYMBOL_TYPE_TAG, Sig);
+      if (TagResolved)
+        SentTypes.append(Sig->Params.begin(), Sig->Params.end());
+      Error |= !TagResolved;
+    }
+    if (Opcode == wasm::WASM_OPCODE_CATCH_REF ||
+        Opcode == wasm::WASM_OPCODE_CATCH_ALL_REF)
+      SentTypes.push_back(wasm::ValType::EXNREF);
+    unsigned Level = Inst.getOperand(OpIdx++).getImm();
+    if (Level >= BlockInfoStack.size()) {
+      Error |= typeError(ErrorLoc, ErrorMsgBase + "invalid depth " +
+                                       std::to_string(Level));
+      continue;
+    }
+    // A branch to a loop's label passes the loop's params, not its results.
+    const auto &Dest = BlockInfoStack[BlockInfoStack.size() - Level - 1];
+    ArrayRef<wasm::ValType> DestTypes = Dest.Sig.Returns;
+    if (Dest.IsLoop)
+      DestTypes = Dest.Sig.Params;
+    // An unresolved tag was already reported; don't add a bogus mismatch.
+    if (TagResolved && compareTypes(SentTypes, DestTypes))
+      Error |= typeError(ErrorLoc, ErrorMsgBase +
+                                       "type mismatch, catch tag type is " +
+                                       getTypesString(SentTypes) +
+                                       ", but destination's return type is " +
+                                       getTypesString(DestTypes));
+  }
+  return Error;
+}
+
 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
                                         OperandVector &Operands) {
   auto Opc = Inst.getOpcode();
@@ -460,10 +524,13 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
     return popType(ErrorLoc, Any{});
   }
 
-  if (Name == "block" || Name == "loop" || Name == "if" || Name == "try") {
+  if (Name == "block" || Name == "loop" || Name == "if" || Name == "try" ||
+      Name == "try_table") {
     bool Error = Name == "if" && popType(ErrorLoc, wasm::ValType::I32);
     // Pop block input parameters and check their types are correct
     Error |= popTypes(ErrorLoc, LastSig.Params);
+    if (Name == "try_table")
+      Error |= checkTryTable(ErrorLoc, Inst);
     // Push a new block info
     BlockInfoStack.push_back({LastSig, Stack.size(), Name == "loop"});
     // Push back block input parameters
@@ -472,8 +539,8 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
   }
 
   if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
-      Name == "end_try" || Name == "delegate" || Name == "else" ||
-      Name == "catch" || Name == "catch_all") {
+      Name == "end_try" || Name == "delegate" || Name == "end_try_table" ||
+      Name == "else" || Name == "catch" || Name == "catch_all") {
     assert(!BlockInfoStack.empty());
     // Check if the types on the stack match with the block return type
     const auto &LastBlockInfo = BlockInfoStack.back();
@@ -586,6 +653,12 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
     return Error;
   }
 
+  if (Name == "throw_ref") {
+    bool Error = popType(ErrorLoc, wasm::ValType::EXNREF);
+    pushType(Polymorphic{});
+    return Error;
+  }
+
   // The current instruction is a stack instruction which doesn't have
   // explicit operands that indicate push/pop types, so we get those from
   // the register version of the same instruction.
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 596fb27bce94e6..e6fddf98060265 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -65,9 +65,11 @@ class WebAssemblyAsmTypeCheck final {
   void pushTypes(ArrayRef<wasm::ValType> Types);
   void pushType(StackType Type) { Stack.push_back(Type); }
   bool match(StackType TypeA, StackType TypeB);
-  std::string getTypesString(ArrayRef<StackType> Types, size_t StartPos);
+  std::string getTypesString(ArrayRef<wasm::ValType> Types,
+                             size_t StartPos = 0);
+  std::string getTypesString(ArrayRef<StackType> Types, size_t StartPos = 0);
   SmallVector<StackType, 4>
-  valTypeToStackType(ArrayRef<wasm::ValType> ValTypes);
+  valTypesToStackTypes(ArrayRef<wasm::ValType> ValTypes);
 
   void dumpTypeStack(Twine Msg);
   bool typeError(SMLoc ErrorLoc, const Twine &Msg);
@@ -80,6 +82,7 @@ class WebAssemblyAsmTypeCheck final {
   bool getTable(SMLoc ErrorLoc, const MCOperand &TableOp, wasm::ValType &Type);
   bool getSignature(SMLoc ErrorLoc, const MCOperand &SigOp,
                     wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
+  bool checkTryTable(SMLoc ErrorLoc, const MCInst &Inst);
 
 public:
   WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,
diff --git a/llvm/test/MC/WebAssembly/eh-assembly.s b/llvm/test/MC/WebAssembly/eh-assembly.s
index b4d6b324d96e3e..a03c1b8e1aed14 100644
--- a/llvm/test/MC/WebAssembly/eh-assembly.s
+++ b/llvm/test/MC/WebAssembly/eh-assembly.s
@@ -1,6 +1,6 @@
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+exception-handling --no-type-check < %s | FileCheck %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+exception-handling < %s | FileCheck %s
 # Check that it converts to .o without errors, but don't check any output:
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+exception-handling --no-type-check -o %t.o < %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+exception-handling -o %t.o < %s
 
   .tagtype  __cpp_exception i32
   .tagtype  __c_longjmp i32
diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s
index df537a9ba5d0a0..74ab17fdefdad9 100644
--- a/llvm/test/MC/WebAssembly/type-checker-errors.s
+++ b/llvm/test/MC/WebAssembly/type-checker-errors.s
@@ -944,3 +944,26 @@ block_param_and_return:
 
 # CHECK: :[[@LINE+1]]:3: error: type mismatch, expected [] but got [f32]
   end_function
+
+  .tagtype  __cpp_exception i32
+
+eh_test:
+  .functype eh_test () -> ()
+  block i32
+    block i32
+      block i32
+        block
+# CHECK: :[[@LINE+4]]:11: error: try_table: catch index 0: type mismatch, catch tag type is [i32], but destination's return type is []
+# CHECK: :[[@LINE+3]]:11: error: try_table: catch index 1: type mismatch, catch tag type is [i32, exnref], but destination's return type is [i32]
+# CHECK: :[[@LINE+2]]:11: error: try_table: catch index 2: type mismatch, catch tag type is [], but destination's return type is [i32]
+# CHECK: :[[@LINE+1]]:11: error: try_table: catch index 3: type mismatch, catch tag type is [exnref], but destination's return type is [i32]
+          try_table i32 (catch __cpp_exception 0) (catch_ref __cpp_exception 1) (catch_all 2) (catch_all_ref 3)
+# CHECK: :[[@LINE+1]]:11: error: type mismatch, expected [i32] but got []
+          end_try_table
+# CHECK: :[[@LINE+1]]:9: error: type mismatch, expected [] but got [i32]
+        end_block
+      end_block
+    end_block
+  end_block
+  drop
+  end_function

>From 6068e77a6c6835ef23d3806d32ff85aec672daf7 Mon Sep 17 00:00:00 2001
From: Heejin Ahn <aheejin at gmail.com>
Date: Thu, 3 Oct 2024 22:31:46 +0000
Subject: [PATCH 2/2] throw_ref

---
 llvm/test/MC/WebAssembly/eh-assembly.s | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llvm/test/MC/WebAssembly/eh-assembly.s b/llvm/test/MC/WebAssembly/eh-assembly.s
index a03c1b8e1aed14..d28a7048fb6114 100644
--- a/llvm/test/MC/WebAssembly/eh-assembly.s
+++ b/llvm/test/MC/WebAssembly/eh-assembly.s
@@ -24,7 +24,6 @@ eh_test:
         return
       end_block
       throw_ref
-      drop
     end_block
     return
   end_block
@@ -101,7 +100,6 @@ eh_test:
 # CHECK-NEXT:    return
 # CHECK-NEXT:    end_block
 # CHECK-NEXT:    throw_ref
-# CHECK-NEXT:    drop
 # CHECK-NEXT:    end_block
 # CHECK-NEXT:    return
 # CHECK-NEXT:    end_block



More information about the llvm-commits mailing list