[llvm] 5495c36 - [WebAssembly] Misc. refactoring in AsmTypeCheck (NFC) (#107978)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 10 14:23:00 PDT 2024
Author: Heejin Ahn
Date: 2024-09-10T14:22:57-07:00
New Revision: 5495c36104103c4172808a28e8b2df3c806b1d85
URL: https://github.com/llvm/llvm-project/commit/5495c36104103c4172808a28e8b2df3c806b1d85
DIFF: https://github.com/llvm/llvm-project/commit/5495c36104103c4172808a28e8b2df3c806b1d85.diff
LOG: [WebAssembly] Misc. refactoring in AsmTypeCheck (NFC) (#107978)
Existing methods in AsmTypeCheck assume the symbol operand is the 0th
operand; they take an `MCInst` and call `getOperand(0)` on it. I think
passing an `MCOperand` removes this assumption and is also more
intuitive. This was motivated by a new `try_table` instruction, whose
support is going to be added to AsmTypeCheck soon, which has tag symbol
operands in any position, depending on the number and the kinds of catch
clauses. This PR changes the signatures of all methods that assume the 0th
operand is the relevant one, even if it's not the symbol operand.
This also adds a `getSignature` method, which factors out the common task
of getting a `WasmSignature` from an `MCOperand`.
Added:
Modified:
llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index 9f9e7d1c0ed066..ec3d51d4e0e843 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -112,9 +112,9 @@ bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
return false;
}
-bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
wasm::ValType &Type) {
- auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
+ auto Local = static_cast<size_t>(LocalOp.getImm());
if (Local >= LocalTypes.size())
return typeError(ErrorLoc, StringRef("no local type specified for index ") +
std::to_string(Local));
@@ -178,21 +178,21 @@ bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
return false;
}
-bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
const MCSymbolRefExpr *&SymRef) {
- auto Op = Inst.getOperand(0);
- if (!Op.isExpr())
+ if (!SymOp.isExpr())
return typeError(ErrorLoc, StringRef("expected expression operand"));
- SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
+ SymRef = dyn_cast<MCSymbolRefExpr>(SymOp.getExpr());
if (!SymRef)
return typeError(ErrorLoc, StringRef("expected symbol operand"));
return false;
}
-bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc,
+ const MCOperand &GlobalOp,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Inst, SymRef))
+ if (getSymRef(ErrorLoc, GlobalOp, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
@@ -217,10 +217,10 @@ bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
return false;
}
-bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCOperand &TableOp,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Inst, SymRef))
+ if (getSymRef(ErrorLoc, TableOp, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
@@ -231,6 +231,34 @@ bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
return false;
}
+bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
+ const MCOperand &SigOp,
+ wasm::WasmSymbolType Type,
+ const wasm::WasmSignature *&Sig) {
+ const MCSymbolRefExpr *SymRef = nullptr;
+ if (getSymRef(ErrorLoc, SigOp, SymRef))
+ return true;
+ const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
+ Sig = WasmSym->getSignature();
+
+ if (!Sig || WasmSym->getType() != Type) {
+ const char *TypeName = nullptr;
+ switch (Type) {
+ case wasm::WASM_SYMBOL_TYPE_FUNCTION:
+ TypeName = "func";
+ break;
+ case wasm::WASM_SYMBOL_TYPE_TAG:
+ TypeName = "tag";
+ break;
+ default:
+ return true;
+ }
+ return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
+ ": missing ." + TypeName + "type");
+ }
+ return false;
+}
+
bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
// Check the return types.
for (auto RVT : llvm::reverse(ReturnTypes)) {
@@ -252,48 +280,48 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
dumpTypeStack("typechecking " + Name + ": ");
wasm::ValType Type;
if (Name == "local.get") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(Type);
} else if (Name == "local.set") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
} else if (Name == "local.tee") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
Stack.push_back(Type);
} else if (Name == "global.get") {
- if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(Type);
} else if (Name == "global.set") {
- if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
} else if (Name == "table.get") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
Stack.push_back(Type);
} else if (Name == "table.set") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
} else if (Name == "table.size") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(wasm::ValType::I32);
} else if (Name == "table.grow") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
@@ -301,7 +329,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return true;
Stack.push_back(wasm::ValType::I32);
} else if (Name == "table.fill") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
@@ -352,15 +380,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return true;
Unreachable = false;
if (Name == "catch") {
- const MCSymbolRefExpr *SymRef;
- if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
+ const wasm::WasmSignature *Sig = nullptr;
+ if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
+ wasm::WASM_SYMBOL_TYPE_TAG, Sig))
return true;
- const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
- const auto *Sig = WasmSym->getSignature();
- if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
- return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
- WasmSym->getName() +
- ": missing .tagtype");
// catch instruction pushes values whose types are specified in the tag's
// "params" part
Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
@@ -383,15 +406,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
return true;
} else if (Name == "call" || Name == "return_call") {
- const MCSymbolRefExpr *SymRef;
- if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
- return true;
- auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
- auto Sig = WasmSym->getSignature();
- if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
- return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
- WasmSym->getName() +
- ": missing .functype");
+ const wasm::WasmSignature *Sig = nullptr;
+ if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
+ wasm::WASM_SYMBOL_TYPE_FUNCTION, Sig))
+ return true;
if (checkSig(ErrorLoc, *Sig))
return true;
if (Name == "return_call" && endOfFunction(ErrorLoc))
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 6fa95c3929753c..9ba5693719e91a 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -41,14 +41,17 @@ class WebAssemblyAsmTypeCheck final {
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
bool popType(SMLoc ErrorLoc, std::optional<wasm::ValType> EVT);
bool popRefType(SMLoc ErrorLoc);
- bool getLocal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
+ bool getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp, wasm::ValType &Type);
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
bool checkBr(SMLoc ErrorLoc, size_t Level);
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
- bool getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
+ bool getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
const MCSymbolRefExpr *&SymRef);
- bool getGlobal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
- bool getTable(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
+ bool getGlobal(SMLoc ErrorLoc, const MCOperand &GlobalOp,
+ wasm::ValType &Type);
+ bool getTable(SMLoc ErrorLoc, const MCOperand &TableOp, wasm::ValType &Type);
+ bool getSignature(SMLoc ErrorLoc, const MCOperand &SigOp,
+ wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
public:
WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,
More information about the llvm-commits
mailing list