[llvm] [WebAssembly] Misc. refactoring in AsmTypeCheck (NFC) (PR #107978)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 10 00:50:49 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-webassembly
Author: Heejin Ahn (aheejin)
<details>
<summary>Changes</summary>
Existing methods in AsmTypeCheck assume the symbol operand is the 0th operand; they take an `MCInst` and call `getOperand(0)` on it. I think passing an `MCOperand` removes this assumption and is also more intuitive. This was motivated by the new `try_table` instruction, support for which will be added to AsmTypeCheck soon. That instruction can have tag symbol operands in any position, depending on the number and kinds of its catch clauses. This PR changes the signatures of all methods that assume the 0th operand is the relevant one, even when it is not a symbol operand.
This also adds a `getSignature` method, which factors out the common logic for getting a `WasmSignature` from an `MCOperand`.
---
Full diff: https://github.com/llvm/llvm-project/pull/107978.diff
2 Files Affected:
- (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp (+51-35)
- (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h (+6-4)
``````````diff
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index f81f4556a00a14..8075bb90b75ad7 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -112,9 +112,9 @@ bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
return false;
}
-bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &Op,
wasm::ValType &Type) {
- auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
+ auto Local = static_cast<size_t>(Op.getImm());
if (Local >= LocalTypes.size())
return typeError(ErrorLoc, StringRef("no local type specified for index ") +
std::to_string(Local));
@@ -178,9 +178,8 @@ bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
return false;
}
-bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCOperand &Op,
const MCSymbolRefExpr *&SymRef) {
- auto Op = Inst.getOperand(0);
if (!Op.isExpr())
return typeError(ErrorLoc, StringRef("expected expression operand"));
SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
@@ -189,10 +188,10 @@ bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
return false;
}
-bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCOperand &Op,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Inst, SymRef))
+ if (getSymRef(ErrorLoc, Op, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
@@ -217,10 +216,10 @@ bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
return false;
}
-bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCOperand &Op,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Inst, SymRef))
+ if (getSymRef(ErrorLoc, Op, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
@@ -245,6 +244,33 @@ bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
return false;
}
+bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc, const MCOperand &Op,
+                                           wasm::WasmSymbolType Type,
+                                           const wasm::WasmSignature *&Sig) {
+  const MCSymbolRefExpr *SymRef = nullptr;
+  if (getSymRef(ErrorLoc, Op, SymRef))
+    return true;
+  const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
+  Sig = WasmSym->getSignature();
+  // Reject symbols that have no signature or are not of the requested kind.
+  if (!Sig || WasmSym->getType() != Type) {
+    const char *TypeName = nullptr;
+    switch (Type) {
+    case wasm::WASM_SYMBOL_TYPE_FUNCTION:
+      TypeName = "func";
+      break;
+    case wasm::WASM_SYMBOL_TYPE_TAG:
+      TypeName = "tag";
+      break;
+    default: // Callers only request function or tag signatures.
+      llvm_unreachable("expected a function or tag symbol type");
+    }
+    return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
+                                   " missing ." + TypeName + "type");
+  }
+  return false;
+}
+
bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
OperandVector &Operands) {
auto Opc = Inst.getOpcode();
@@ -252,48 +278,48 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
dumpTypeStack("typechecking " + Name + ": ");
wasm::ValType Type;
if (Name == "local.get") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(Type);
} else if (Name == "local.set") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
} else if (Name == "local.tee") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
Stack.push_back(Type);
} else if (Name == "global.get") {
- if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(Type);
} else if (Name == "global.set") {
- if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
} else if (Name == "table.get") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
Stack.push_back(Type);
} else if (Name == "table.set") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
} else if (Name == "table.size") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(wasm::ValType::I32);
} else if (Name == "table.grow") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
@@ -301,7 +327,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return true;
Stack.push_back(wasm::ValType::I32);
} else if (Name == "table.fill") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
@@ -352,15 +378,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return true;
Unreachable = false;
if (Name == "catch") {
- const MCSymbolRefExpr *SymRef;
- if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
+ const wasm::WasmSignature *Sig = nullptr;
+ if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
+ wasm::WASM_SYMBOL_TYPE_TAG, Sig))
return true;
- const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
- const auto *Sig = WasmSym->getSignature();
- if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
- return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
- WasmSym->getName() +
- " missing .tagtype");
// catch instruction pushes values whose types are specified in the tag's
// "params" part
Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
@@ -383,15 +404,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
return true;
} else if (Name == "call" || Name == "return_call") {
- const MCSymbolRefExpr *SymRef;
- if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
- return true;
- auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
- auto Sig = WasmSym->getSignature();
- if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
- return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
- WasmSym->getName() +
- " missing .functype");
+ const wasm::WasmSignature *Sig = nullptr;
+ if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
+ wasm::WASM_SYMBOL_TYPE_FUNCTION, Sig))
+ return true;
if (checkSig(ErrorLoc, *Sig))
return true;
if (Name == "return_call" && endOfFunction(ErrorLoc))
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 6fa95c3929753c..38638cea22b4ba 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -41,14 +41,16 @@ class WebAssemblyAsmTypeCheck final {
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
bool popType(SMLoc ErrorLoc, std::optional<wasm::ValType> EVT);
bool popRefType(SMLoc ErrorLoc);
- bool getLocal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
+ bool getLocal(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
bool checkBr(SMLoc ErrorLoc, size_t Level);
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
- bool getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
+ bool getSymRef(SMLoc ErrorLoc, const MCOperand &Op,
const MCSymbolRefExpr *&SymRef);
- bool getGlobal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
- bool getTable(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
+ bool getGlobal(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
+ bool getTable(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
+ bool getSignature(SMLoc ErrorLoc, const MCOperand &Op,
+ wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
public:
WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,
``````````
</details>
https://github.com/llvm/llvm-project/pull/107978
More information about the llvm-commits
mailing list