[llvm] [WebAssembly] Unify type checking in AsmTypeCheck (PR #110094)

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 26 02:27:19 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-backend-webassembly

@llvm/pr-subscribers-mc

Author: Heejin Ahn (aheejin)

<details>
<summary>Changes</summary>

This unifies the way we check types in various places in AsmTypeCheck. The objectives of this PR are:

- We now use `checkTypes` for all type checking and `checkAndPopTypes` for type checking + popping. All other functions are helpers that call these two.

- We now support comparing vectors of types. This lets us print error messages in a more readable way. When an instruction takes [i32, i64] but the stack top is [f32, f64], now instead of
  ```console
  error: type mismatch, expected i64 but got f64
  error: type mismatch, expected i32 but got f32
  ```
  we can print this
  ```console
  error: type mismatch, expected [i32, i64] but got [f32, f64]
  ```
  which is also the format the Wabt checker prints. This also helps print more meaningful messages when there are superfluous values on the stack at the end of the function, such as:
  ```console
  error: type mismatch, expected [] but got [i32, exnref]
  ```
  Many instructions do not use this batch printing yet, so a single instruction can still produce multiple error messages. This will be improved in a follow-up.

- The value stack now supports `Any` and `Ref`. Some instructions require the type to be anything, and instructions like `ref.is_null` require the type to be any reference type. The type comparison function handles these types accordingly, meaning `match(I32, Any)` or `match(externref, Ref)` will succeed.
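
As a rough illustration of the `Any`/`Ref` matching and the bracketed list format described above, here is a minimal standalone sketch. It is not the code from this patch: `ValType` is a hypothetical stand-in for `wasm::ValType`, `Ref` and `Any` are plain empty structs, and `typesMatch`, `isRefType`, `typeName` and `typesToString` are illustrative names. Note that `typesMatch` returns `true` on a match, which is the opposite of the patch's `match`, where `false` means success.

```cpp
#include <cstddef>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for wasm::ValType; only a few members are listed.
enum class ValType { I32, I64, F32, F64, ExternRef, FuncRef, ExnRef };

struct Ref {}; // wildcard: matches any reference type
struct Any {}; // wildcard: matches every type
using StackType = std::variant<ValType, Ref, Any>;

// Assumption for this sketch: these three members are the reference types.
bool isRefType(ValType T) {
  return T == ValType::ExternRef || T == ValType::FuncRef ||
         T == ValType::ExnRef;
}

// Returns true when the two entries are compatible.
// (The patch's match() uses the inverse convention: false means success.)
bool typesMatch(const StackType &A, const StackType &B) {
  if (std::holds_alternative<Any>(A) || std::holds_alternative<Any>(B))
    return true;
  const bool ARef = std::holds_alternative<Ref>(A);
  const bool BRef = std::holds_alternative<Ref>(B);
  if (ARef && BRef)
    return true;
  if (ARef)
    return isRefType(std::get<ValType>(B));
  if (BRef)
    return isRefType(std::get<ValType>(A));
  return std::get<ValType>(A) == std::get<ValType>(B);
}

std::string typeName(const StackType &T) {
  if (std::holds_alternative<Any>(T))
    return "any";
  if (std::holds_alternative<Ref>(T))
    return "ref";
  switch (std::get<ValType>(T)) {
  case ValType::I32:       return "i32";
  case ValType::I64:       return "i64";
  case ValType::F32:       return "f32";
  case ValType::F64:       return "f64";
  case ValType::ExternRef: return "externref";
  case ValType::FuncRef:   return "funcref";
  case ValType::ExnRef:    return "exnref";
  }
  return "invalid";
}

// Formats a list of types as "[a, b, c]", bottom of the stack first.
std::string typesToString(const std::vector<StackType> &Types) {
  std::string S = "[";
  for (std::size_t I = 0; I < Types.size(); ++I) {
    if (I != 0)
      S += ", ";
    S += typeName(Types[I]);
  }
  return S + "]";
}
```

With these definitions, `typesMatch(ValType::I32, Any{})` and `typesMatch(ValType::ExternRef, Ref{})` both hold, while `typesMatch(ValType::I32, Ref{})` does not.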

The changes in `type-checker-errors.s` are mostly message format changes. One downside of the new message format is that it doesn't include instruction names. I plan to improve that in a potential follow-up.

This also modifies some of the instructions in `type-checker-errors.s`. Currently, except for a few functions I've recently added at the end, each function tests for a single error, because the type checker used to bail out after the first error until #109705. But many functions included multiple errors anyway, which I don't think was the intention of the original writer. So I added some instructions to remove the other errors which are not being tested. (In some cases I added more error checking lines instead, when I felt that could be relevant.)

Thanks to the new `ExactMatch` option in the `checkTypes` function family, we can now distinguish the cases when to check against only the top of the value stack and when to check against the whole stack (e.g. to check whether we have any superfluous values remaining at the end of the function). `return` or `return_call(_indirect)` can set `ExactMatch` to `false` because they don't care about the superfluous values. This makes `type-checker-return.s` succeed and I was able to remove the `FIXME`.
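
The effect of `ExactMatch` can be sketched in isolation as follows. This is a simplified model rather than the patch itself: types are plain strings, the stack is passed in explicitly, and the function returns an optional error message instead of reporting through the parser. The names `Type`, `typesToString` and this free-standing `checkTypes` signature are assumptions made for the example.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

using Type = std::string; // hypothetical: types are modelled as their names

// Formats Types[Start..] as "[a, b, c]".
std::string typesToString(const std::vector<Type> &Types,
                          std::size_t Start = 0) {
  std::string S = "[";
  for (std::size_t I = Start; I < Types.size(); ++I) {
    if (I != Start)
      S += ", ";
    S += Types[I];
  }
  return S + "]";
}

// Returns std::nullopt on success, or an error message on mismatch.
// With ExactMatch == false only the top Expected.size() entries are compared;
// with ExactMatch == true the whole stack must equal Expected.
std::optional<std::string> checkTypes(const std::vector<Type> &Stack,
                                      const std::vector<Type> &Expected,
                                      bool ExactMatch) {
  std::size_t StackI = Stack.size();
  std::size_t TypeI = Expected.size();
  bool Error = false;
  // Compare both lists from the top of the stack downwards.
  for (; StackI > 0 && TypeI > 0; --StackI, --TypeI) {
    if (Stack[StackI - 1] != Expected[TypeI - 1]) {
      Error = true;
      break;
    }
  }
  // Missing values are always an error; leftover values only if ExactMatch.
  if (TypeI > 0 || (ExactMatch && StackI > 0))
    Error = true;
  if (!Error)
    return std::nullopt;

  // Print the whole stack for exact checks, otherwise only the compared top.
  const std::size_t Start = (ExactMatch || Stack.size() < Expected.size())
                                ? 0
                                : Stack.size() - Expected.size();
  return "type mismatch, expected " + typesToString(Expected) + " but got " +
         typesToString(Stack, Start);
}
```

In this model, a stack of `[i32, exnref]` checked against an empty expectation fails only when `ExactMatch` is `true`, which corresponds to the difference between `end_function` and `return` described above.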

This is the basis of the PR that fixes block parameter/return type handling in the checker, but does not yet include the actual block-related functionality, which will be submitted separately after this PR.

---

Patch is 46.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110094.diff


6 Files Affected:

- (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp (+1-1) 
- (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp (+140-98) 
- (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h (+27-4) 
- (modified) llvm/test/MC/WebAssembly/basic-assembly.s (+10) 
- (modified) llvm/test/MC/WebAssembly/type-checker-errors.s (+126-94) 
- (modified) llvm/test/MC/WebAssembly/type-checker-return.s (-5) 


``````````diff
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
index 129fdaf37fc0d8..95db5500b0e1b1 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
@@ -1255,7 +1255,7 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
 
   void onEndOfFunction(SMLoc ErrorLoc) {
     if (!SkipTypeCheck)
-      TC.endOfFunction(ErrorLoc);
+      TC.endOfFunction(ErrorLoc, true);
     // Reset the type checker state.
     TC.clear();
   }
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index 8b1e1dca4f8474..2f000354182fcb 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -33,6 +33,7 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/SourceMgr.h"
+#include <sstream>
 
 using namespace llvm;
 
@@ -59,14 +60,7 @@ void WebAssemblyAsmTypeCheck::localDecl(
 }
 
 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
-  LLVM_DEBUG({
-    std::string s;
-    for (auto VT : Stack) {
-      s += WebAssembly::typeToString(VT);
-      s += " ";
-    }
-    dbgs() << Msg << s << '\n';
-  });
+  LLVM_DEBUG({ dbgs() << Msg << getTypesString(Stack, 0); });
 }
 
 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
@@ -77,34 +71,119 @@ bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
   return Parser.Error(ErrorLoc, Msg);
 }
 
-bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
-                                      std::optional<wasm::ValType> EVT) {
-  if (Stack.empty()) {
-    return typeError(ErrorLoc,
-                     EVT ? StringRef("empty stack while popping ") +
-                               WebAssembly::typeToString(*EVT)
-                         : StringRef("empty stack while popping value"));
+bool WebAssemblyAsmTypeCheck::match(StackType TypeA, StackType TypeB) {
+  if (TypeA == TypeB)
+    return false;
+  if (std::get_if<Any>(&TypeA) || std::get_if<Any>(&TypeB))
+    return false;
+
+  if (std::get_if<Ref>(&TypeB))
+    std::swap(TypeA, TypeB);
+  assert(std::get_if<wasm::ValType>(&TypeB));
+  if (std::get_if<Ref>(&TypeA) &&
+      WebAssembly::isRefType(std::get<wasm::ValType>(TypeB)))
+    return false;
+  return true;
+}
+
+std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
+                                                    size_t StartPos) {
+  SmallVector<std::string, 4> Reverse;
+  for (auto I = Types.size(); I > StartPos; I--) {
+    if (std::get_if<Any>(&Types[I - 1]))
+      Reverse.push_back("any");
+    else if (std::get_if<Ref>(&Types[I - 1]))
+      Reverse.push_back("ref");
+    else
+      Reverse.push_back(
+          WebAssembly::typeToString(std::get<wasm::ValType>(Types[I - 1])));
   }
-  auto PVT = Stack.pop_back_val();
-  if (EVT && *EVT != PVT) {
-    return typeError(ErrorLoc,
-                     StringRef("popped ") + WebAssembly::typeToString(PVT) +
-                         ", expected " + WebAssembly::typeToString(*EVT));
+
+  std::stringstream SS;
+  SS << "[";
+  bool First = true;
+  for (auto It = Reverse.rbegin(); It != Reverse.rend(); ++It) {
+    if (!First)
+      SS << ", ";
+    SS << *It;
+    First = false;
   }
-  return false;
+  SS << "]";
+  return SS.str();
 }
 
-bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
-  if (Stack.empty()) {
-    return typeError(ErrorLoc, StringRef("empty stack while popping reftype"));
-  }
-  auto PVT = Stack.pop_back_val();
-  if (!WebAssembly::isRefType(PVT)) {
-    return typeError(ErrorLoc, StringRef("popped ") +
-                                   WebAssembly::typeToString(PVT) +
-                                   ", expected reftype");
+SmallVector<WebAssemblyAsmTypeCheck::StackType, 4>
+WebAssemblyAsmTypeCheck::valTypeToStackType(ArrayRef<wasm::ValType> ValTypes) {
+  SmallVector<StackType, 4> Types(ValTypes.size());
+  std::transform(ValTypes.begin(), ValTypes.end(), Types.begin(),
+                 [](wasm::ValType Val) -> StackType { return Val; });
+  return Types;
+}
+
+bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
+                                         ArrayRef<wasm::ValType> ValTypes,
+                                         bool ExactMatch) {
+  return checkTypes(ErrorLoc, valTypeToStackType(ValTypes), ExactMatch);
+}
+
+bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
+                                         ArrayRef<StackType> Types,
+                                         bool ExactMatch) {
+  auto StackI = Stack.size();
+  auto TypeI = Types.size();
+  bool Error = false;
+  for (; StackI > 0 && TypeI > 0; StackI--, TypeI--) {
+    if (match(Stack[StackI - 1], Types[TypeI - 1])) {
+      Error = true;
+      break;
+    }
   }
-  return false;
+  if (TypeI > 0 || (ExactMatch && StackI > 0))
+    Error = true;
+
+  if (!Error)
+    return false;
+
+  auto StackStartPos =
+      ExactMatch ? 0 : std::max(0, (int)Stack.size() - (int)Types.size());
+  return typeError(ErrorLoc, "type mismatch, expected " +
+                                 getTypesString(Types, 0) + " but got " +
+                                 getTypesString(Stack, StackStartPos));
+}
+
+bool WebAssemblyAsmTypeCheck::checkAndPopTypes(SMLoc ErrorLoc,
+                                               ArrayRef<wasm::ValType> ValTypes,
+                                               bool ExactMatch) {
+  SmallVector<StackType, 4> Types(ValTypes.size());
+  std::transform(ValTypes.begin(), ValTypes.end(), Types.begin(),
+                 [](wasm::ValType Val) -> StackType { return Val; });
+  return checkAndPopTypes(ErrorLoc, Types, ExactMatch);
+}
+
+bool WebAssemblyAsmTypeCheck::checkAndPopTypes(SMLoc ErrorLoc,
+                                               ArrayRef<StackType> Types,
+                                               bool ExactMatch) {
+  bool Error = checkTypes(ErrorLoc, Types, ExactMatch);
+  auto NumPops = std::min(Stack.size(), Types.size());
+  for (size_t I = 0, E = NumPops; I != E; I++)
+    Stack.pop_back();
+  return Error;
+}
+
+bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc, StackType Type) {
+  return checkAndPopTypes(ErrorLoc, {Type}, false);
+}
+
+bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
+  return popType(ErrorLoc, Ref{});
+}
+
+bool WebAssemblyAsmTypeCheck::popAnyType(SMLoc ErrorLoc) {
+  return popType(ErrorLoc, Any{});
+}
+
+void WebAssemblyAsmTypeCheck::pushTypes(ArrayRef<wasm::ValType> ValTypes) {
+  Stack.append(valTypeToStackType(ValTypes));
 }
 
 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
@@ -117,59 +196,29 @@ bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
   return false;
 }
 
-static std::optional<std::string>
-checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop,
-              const SmallVectorImpl<wasm::ValType> &Got) {
-  for (size_t I = 0; I < ExpectedStackTop.size(); I++) {
-    auto EVT = ExpectedStackTop[I];
-    auto PVT = Got[Got.size() - ExpectedStackTop.size() + I];
-    if (PVT != EVT)
-      return std::string{"got "} + WebAssembly::typeToString(PVT) +
-             ", expected " + WebAssembly::typeToString(EVT);
-  }
-  return std::nullopt;
-}
-
 bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
   if (Level >= BrStack.size())
     return typeError(ErrorLoc,
                      StringRef("br: invalid depth ") + std::to_string(Level));
   const SmallVector<wasm::ValType, 4> &Expected =
       BrStack[BrStack.size() - Level - 1];
-  if (Expected.size() > Stack.size())
-    return typeError(ErrorLoc, "br: insufficient values on the type stack");
-  auto IsStackTopInvalid = checkStackTop(Expected, Stack);
-  if (IsStackTopInvalid)
-    return typeError(ErrorLoc, "br " + IsStackTopInvalid.value());
+  return checkTypes(ErrorLoc, Expected, false);
   return false;
 }
 
 bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
   if (!PopVals)
     BrStack.pop_back();
-  if (LastSig.Returns.size() > Stack.size())
-    return typeError(ErrorLoc, "end: insufficient values on the type stack");
 
-  if (PopVals) {
-    for (auto VT : llvm::reverse(LastSig.Returns)) {
-      if (popType(ErrorLoc, VT))
-        return true;
-    }
-    return false;
-  }
-
-  auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack);
-  if (IsStackTopInvalid)
-    return typeError(ErrorLoc, "end " + IsStackTopInvalid.value());
-  return false;
+  if (PopVals)
+    return checkAndPopTypes(ErrorLoc, LastSig.Returns, false);
+  return checkTypes(ErrorLoc, LastSig.Returns, false);
 }
 
 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
                                        const wasm::WasmSignature &Sig) {
-  bool Error = false;
-  for (auto VT : llvm::reverse(Sig.Params))
-    Error |= popType(ErrorLoc, VT);
-  Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
+  bool Error = checkAndPopTypes(ErrorLoc, Sig.Params, false);
+  pushTypes(Sig.Returns);
   return Error;
 }
 
@@ -246,7 +295,7 @@ bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
       TypeName = "tag";
       break;
     default:
-      return true;
+      assert(false);
     }
     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
                                    ": missing ." + TypeName + "type");
@@ -254,15 +303,8 @@ bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
   return false;
 }
 
-bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
-  bool Error = false;
-  // Check the return types.
-  for (auto RVT : llvm::reverse(ReturnTypes))
-    Error |= popType(ErrorLoc, RVT);
-  if (!Stack.empty()) {
-    return typeError(ErrorLoc, std::to_string(Stack.size()) +
-                                   " superfluous return values");
-  }
+bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc, bool ExactMatch) {
+  bool Error = checkTypes(ErrorLoc, ReturnTypes, ExactMatch);
   Unreachable = true;
   return Error;
 }
@@ -276,7 +318,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
 
   if (Name == "local.get") {
     if (!getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type)) {
-      Stack.push_back(Type);
+      pushType(Type);
       return false;
     }
     return true;
@@ -291,7 +333,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
   if (Name == "local.tee") {
     if (!getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type)) {
       bool Error = popType(ErrorLoc, Type);
-      Stack.push_back(Type);
+      pushType(Type);
       return Error;
     }
     return true;
@@ -299,7 +341,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
 
   if (Name == "global.get") {
     if (!getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type)) {
-      Stack.push_back(Type);
+      pushType(Type);
       return false;
     }
     return true;
@@ -314,7 +356,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
   if (Name == "table.get") {
     bool Error = popType(ErrorLoc, wasm::ValType::I32);
     if (!getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type)) {
-      Stack.push_back(Type);
+      pushType(Type);
       return Error;
     }
     return true;
@@ -332,7 +374,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
 
   if (Name == "table.size") {
     bool Error = getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type);
-    Stack.push_back(wasm::ValType::I32);
+    pushType(wasm::ValType::I32);
     return Error;
   }
 
@@ -342,7 +384,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
       Error |= popType(ErrorLoc, Type);
     else
       Error = true;
-    Stack.push_back(wasm::ValType::I32);
+    pushType(wasm::ValType::I32);
     return Error;
   }
 
@@ -381,7 +423,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
   }
 
   if (Name == "drop") {
-    return popType(ErrorLoc, {});
+    return popType(ErrorLoc, Any{});
   }
 
   if (Name == "try" || Name == "block" || Name == "loop" || Name == "if") {
@@ -406,7 +448,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
                         wasm::WASM_SYMBOL_TYPE_TAG, Sig))
         // catch instruction pushes values whose types are specified in the
         // tag's "params" part
-        Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
+        pushTypes(Sig->Params);
       else
         Error = true;
     }
@@ -421,14 +463,14 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
   }
 
   if (Name == "return") {
-    return endOfFunction(ErrorLoc);
+    return endOfFunction(ErrorLoc, false);
   }
 
   if (Name == "call_indirect" || Name == "return_call_indirect") {
     // Function value.
     bool Error = popType(ErrorLoc, wasm::ValType::I32);
     Error |= checkSig(ErrorLoc, LastSig);
-    if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
+    if (Name == "return_call_indirect" && endOfFunction(ErrorLoc, false))
       return true;
     return Error;
   }
@@ -441,7 +483,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
       Error |= checkSig(ErrorLoc, *Sig);
     else
       Error = true;
-    if (Name == "return_call" && endOfFunction(ErrorLoc))
+    if (Name == "return_call" && endOfFunction(ErrorLoc, false))
       return true;
     return Error;
   }
@@ -453,7 +495,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
 
   if (Name == "ref.is_null") {
     bool Error = popRefType(ErrorLoc);
-    Stack.push_back(wasm::ValType::I32);
+    pushType(wasm::ValType::I32);
     return Error;
   }
 
@@ -471,22 +513,22 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
   auto RegOpc = WebAssembly::getRegisterOpcode(Opc);
   assert(RegOpc != -1 && "Failed to get register version of MC instruction");
   const auto &II = MII.get(RegOpc);
-  bool Error = false;
   // First pop all the uses off the stack and check them.
-  for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) {
-    const auto &Op = II.operands()[I - 1];
-    if (Op.OperandType == MCOI::OPERAND_REGISTER) {
-      auto VT = WebAssembly::regClassToValType(Op.RegClass);
-      Error |= popType(ErrorLoc, VT);
-    }
+  SmallVector<wasm::ValType, 4> PopTypes;
+  for (unsigned I = II.getNumDefs(); I < II.getNumOperands(); I++) {
+    const auto &Op = II.operands()[I];
+    if (Op.OperandType == MCOI::OPERAND_REGISTER)
+      PopTypes.push_back(WebAssembly::regClassToValType(Op.RegClass));
   }
+  bool Error = checkAndPopTypes(ErrorLoc, PopTypes, false);
+  SmallVector<wasm::ValType, 4> PushTypes;
   // Now push all the defs onto the stack.
   for (unsigned I = 0; I < II.getNumDefs(); I++) {
     const auto &Op = II.operands()[I];
     assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected");
-    auto VT = WebAssembly::regClassToValType(Op.RegClass);
-    Stack.push_back(VT);
+    PushTypes.push_back(WebAssembly::regClassToValType(Op.RegClass));
   }
+  pushTypes(PushTypes);
   return Error;
 }
 
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 972162d3e02f46..9fd35a26f30e50 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -21,6 +21,7 @@
 #include "llvm/MC/MCParser/MCAsmParser.h"
 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
 #include "llvm/MC/MCSymbol.h"
+#include <variant>
 
 namespace llvm {
 
@@ -28,7 +29,10 @@ class WebAssemblyAsmTypeCheck final {
   MCAsmParser &Parser;
   const MCInstrInfo &MII;
 
-  SmallVector<wasm::ValType, 8> Stack;
+  struct Ref : public std::monostate {};
+  struct Any : public std::monostate {};
+  using StackType = std::variant<wasm::ValType, Ref, Any>;
+  SmallVector<StackType, 16> Stack;
   SmallVector<SmallVector<wasm::ValType, 4>, 8> BrStack;
   SmallVector<wasm::ValType, 16> LocalTypes;
   SmallVector<wasm::ValType, 4> ReturnTypes;
@@ -36,10 +40,29 @@ class WebAssemblyAsmTypeCheck final {
   bool Unreachable = false;
   bool Is64;
 
+  // If ExactMatch is true, 'Types' will be compared against not only the top of
+  // the value stack but the whole remaining value stack
+  // (TODO: This should be the whole remaining value stack "at the the current
+  // block level", which has not been implemented yet)
+  bool checkTypes(SMLoc ErrorLoc, ArrayRef<wasm::ValType> Types,
+                  bool ExactMatch);
+  bool checkTypes(SMLoc ErrorLoc, ArrayRef<StackType> Types, bool ExactMatch);
+  bool checkAndPopTypes(SMLoc ErrorLoc, ArrayRef<wasm::ValType> Types,
+                        bool ExactMatch);
+  bool checkAndPopTypes(SMLoc ErrorLoc, ArrayRef<StackType> Types,
+                        bool ExactMatch);
+  bool popType(SMLoc ErrorLoc, StackType Type);
+  bool popRefType(SMLoc ErrorLoc);
+  bool popAnyType(SMLoc ErrorLoc);
+  void pushTypes(ArrayRef<wasm::ValType> Types);
+  void pushType(StackType Type) { Stack.push_back(Type); }
+  bool match(StackType TypeA, StackType TypeB);
+  std::string getTypesString(ArrayRef<StackType> Types, size_t StartPos);
+  SmallVector<StackType, 4>
+  valTypeToStackType(ArrayRef<wasm::ValType> ValTypes);
+
   void dumpTypeStack(Twine Msg);
   bool typeError(SMLoc ErrorLoc, const Twine &Msg);
-  bool popType(SMLoc ErrorLoc, std::optional<wasm::ValType> EVT);
-  bool popRefType(SMLoc ErrorLoc);
   bool getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp, wasm::ValType &Type);
   bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
   bool checkBr(SMLoc ErrorLoc, size_t Level);
@@ -59,7 +82,7 @@ class WebAssemblyAsmTypeCheck final {
   void funcDecl(const wasm::WasmSignature &Sig);
   void localDecl(const SmallVectorImpl<wasm::ValType> &Locals);
   void setLastSig(const wasm::WasmSignature &Sig) { LastSig = Sig; }
-  bool endOfFunction(SMLoc ErrorLoc);
+  bool endOfFunction(SMLoc ErrorLoc, bool ExactMatch);
   bool typeCheck(SMLoc ErrorLoc, const MCInst &Inst, OperandVector &Operands);
 
   void clear() {
diff --git a/llvm/test/MC/WebAssembly/basic-assembly.s b/llvm/test/MC/WebAssembly/basic-assembly.s
index db7ccc9759beca..6cca87d77c20f5 100644
--- a/llvm/test/MC/WebAssembly/basic-assembly.s
+++ b/llvm/test/MC/WebAssembly/basic-assembly.s
@@ -119,6 +119,13 @@ test0:
     #i32.trunc_sat_f32_s
     global.get  __stack_pointer
     global.set  __stack_pointer
+    # FIXME Currently block parameter and return types are not handled
+    # correctly, causing some types to remain on the stack. This test will be
+    # fixed to be valid with the follow-up PR. Until then, to suppress the
+    # return type error, we add some drops here.
+    drop
+    drop
+    drop
     end_function
 
     .section    .rodata..L.str,"",@
@@ -255,6 +262,9 @@ empty_exnref_table:
 # CHECK-NEXT:  .LBB0_4:
 # CHECK-NEXT:      global.get  __stack_pointer
 # CHECK-NEXT:      global.set  __stack_pointer
+# CHECK-NEXT:      drop
+# CHECK-NEXT:      drop
+# CHECK-NEXT:      drop
 # CHECK-NEXT:      end_function
 
 # CHECK:           .section    .rodata..L.str,"",@
diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s
index 3106fe76c8449f..5fdc2f56daf57b 100644
--- a/llvm/test/MC/WebAssembly/type-checker-errors.s
+++ b/llvm/test/MC/WebAssembly/type-checker-errors.s
@@ -19,7 +19,7 @@ local_set_no_local_type:
 local_set_empty_stack_while_popping:
   .functype local_set_empty_stack_while_popping () -> ()
   .local i32
-# CHECK: [[@LINE+1]]:3: error: empty stack while popping i32
+# CHECK: [[@LINE+1]]:3: error: type mismatch, expected [i32] but got []
   local.set 0
   end_function
 
@@ -27,7 +27,7 @@ local_set_type_mismatch:
   .functype local_set_type_mismatch () -> ()
   .local i32
   f32.const 1.0
-# CHECK: [[@LINE+1]]:3: error: popped f32, expected i32
+# CHECK: [[@LINE+1]]:3: error: type mismatch, expected [i32] but got [f32]
   local.set 0
 ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/110094

