[llvm] [WebAssembly] Misc. refactoring in AsmTypeCheck (NFC) (PR #107978)
Heejin Ahn via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 10 11:11:13 PDT 2024
https://github.com/aheejin updated https://github.com/llvm/llvm-project/pull/107978
>From 044069b3b825bedc238d2bf1ea7b7d5be1ececa4 Mon Sep 17 00:00:00 2001
From: Heejin Ahn <aheejin at gmail.com>
Date: Tue, 10 Sep 2024 05:58:12 +0000
Subject: [PATCH 1/2] [WebAssembly] Misc. refactoring in AsmTypeCheck (NFC)
Existing methods in AsmTypeCheck assume the symbol operand is the 0th
operand; they take an `MCInst` and call `getOperand(0)` on it. I think
passing an `MCOperand` removes this assumption and is also more
intuitive. This was motivated by the new `try_table` instruction, support
for which will be added to AsmTypeCheck soon; it can have tag symbol
operands in any position, depending on the number and kinds of catch
clauses. This PR changes the signatures of all methods that assume the
0th operand is the relevant one, even when that operand is not a symbol
operand.
This also adds a `getSignature` method, which factors out the common logic
for getting a `WasmSignature` from an `MCOperand`.
---
.../AsmParser/WebAssemblyAsmTypeCheck.cpp | 86 +++++++++++--------
.../AsmParser/WebAssemblyAsmTypeCheck.h | 10 ++-
2 files changed, 57 insertions(+), 39 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index f81f4556a00a14..8075bb90b75ad7 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -112,9 +112,9 @@ bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
return false;
}
-bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &Op,
wasm::ValType &Type) {
- auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
+ auto Local = static_cast<size_t>(Op.getImm());
if (Local >= LocalTypes.size())
return typeError(ErrorLoc, StringRef("no local type specified for index ") +
std::to_string(Local));
@@ -178,9 +178,8 @@ bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
return false;
}
-bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCOperand &Op,
const MCSymbolRefExpr *&SymRef) {
- auto Op = Inst.getOperand(0);
if (!Op.isExpr())
return typeError(ErrorLoc, StringRef("expected expression operand"));
SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
@@ -189,10 +188,10 @@ bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
return false;
}
-bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCOperand &Op,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Inst, SymRef))
+ if (getSymRef(ErrorLoc, Op, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
@@ -217,10 +216,10 @@ bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
return false;
}
-bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
+bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCOperand &Op,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Inst, SymRef))
+ if (getSymRef(ErrorLoc, Op, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
@@ -245,6 +244,33 @@ bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
return false;
}
+bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc, const MCOperand &Op,
+ wasm::WasmSymbolType Type,
+ const wasm::WasmSignature *&Sig) {
+ const MCSymbolRefExpr *SymRef = nullptr;
+ if (getSymRef(ErrorLoc, Op, SymRef))
+ return true;
+ const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
+ Sig = WasmSym->getSignature();
+
+ if (!Sig || WasmSym->getType() != Type) {
+ const char *TypeName = nullptr;
+ switch (Type) {
+ case wasm::WASM_SYMBOL_TYPE_FUNCTION:
+ TypeName = "func";
+ break;
+ case wasm::WASM_SYMBOL_TYPE_TAG:
+ TypeName = "tag";
+ break;
+ default:
+ return true;
+ }
+ return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
+ " missing ." + TypeName + "type");
+ }
+ return false;
+}
+
bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
OperandVector &Operands) {
auto Opc = Inst.getOpcode();
@@ -252,48 +278,48 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
dumpTypeStack("typechecking " + Name + ": ");
wasm::ValType Type;
if (Name == "local.get") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(Type);
} else if (Name == "local.set") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
} else if (Name == "local.tee") {
- if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
Stack.push_back(Type);
} else if (Name == "global.get") {
- if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(Type);
} else if (Name == "global.set") {
- if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
+ if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
} else if (Name == "table.get") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
Stack.push_back(Type);
} else if (Name == "table.set") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
} else if (Name == "table.size") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
Stack.push_back(wasm::ValType::I32);
} else if (Name == "table.grow") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
@@ -301,7 +327,7 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return true;
Stack.push_back(wasm::ValType::I32);
} else if (Name == "table.fill") {
- if (getTable(Operands[1]->getStartLoc(), Inst, Type))
+ if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
return true;
if (popType(ErrorLoc, wasm::ValType::I32))
return true;
@@ -352,15 +378,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
return true;
Unreachable = false;
if (Name == "catch") {
- const MCSymbolRefExpr *SymRef;
- if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
+ const wasm::WasmSignature *Sig = nullptr;
+ if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
+ wasm::WASM_SYMBOL_TYPE_TAG, Sig))
return true;
- const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
- const auto *Sig = WasmSym->getSignature();
- if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
- return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
- WasmSym->getName() +
- " missing .tagtype");
// catch instruction pushes values whose types are specified in the tag's
// "params" part
Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
@@ -383,15 +404,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
return true;
} else if (Name == "call" || Name == "return_call") {
- const MCSymbolRefExpr *SymRef;
- if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
- return true;
- auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
- auto Sig = WasmSym->getSignature();
- if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
- return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
- WasmSym->getName() +
- " missing .functype");
+ const wasm::WasmSignature *Sig = nullptr;
+ if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
+ wasm::WASM_SYMBOL_TYPE_FUNCTION, Sig))
+ return true;
if (checkSig(ErrorLoc, *Sig))
return true;
if (Name == "return_call" && endOfFunction(ErrorLoc))
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 6fa95c3929753c..38638cea22b4ba 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -41,14 +41,16 @@ class WebAssemblyAsmTypeCheck final {
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
bool popType(SMLoc ErrorLoc, std::optional<wasm::ValType> EVT);
bool popRefType(SMLoc ErrorLoc);
- bool getLocal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
+ bool getLocal(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
bool checkBr(SMLoc ErrorLoc, size_t Level);
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
- bool getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
+ bool getSymRef(SMLoc ErrorLoc, const MCOperand &Op,
const MCSymbolRefExpr *&SymRef);
- bool getGlobal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
- bool getTable(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
+ bool getGlobal(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
+ bool getTable(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
+ bool getSignature(SMLoc ErrorLoc, const MCOperand &Op,
+ wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
public:
WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,
>From a2d3dfb928709d9e849872910564242e5ecca61d Mon Sep 17 00:00:00 2001
From: Heejin Ahn <aheejin at gmail.com>
Date: Tue, 10 Sep 2024 18:08:06 +0000
Subject: [PATCH 2/2] Change 'Op' to more descriptive names
---
.../AsmParser/WebAssemblyAsmTypeCheck.cpp | 52 ++++++++++---------
.../AsmParser/WebAssemblyAsmTypeCheck.h | 11 ++--
2 files changed, 33 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index 8075bb90b75ad7..d0245d656c7a8b 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -112,9 +112,9 @@ bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
return false;
}
-bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &Op,
+bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
wasm::ValType &Type) {
- auto Local = static_cast<size_t>(Op.getImm());
+ auto Local = static_cast<size_t>(LocalOp.getImm());
if (Local >= LocalTypes.size())
return typeError(ErrorLoc, StringRef("no local type specified for index ") +
std::to_string(Local));
@@ -178,20 +178,21 @@ bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
return false;
}
-bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCOperand &Op,
+bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
const MCSymbolRefExpr *&SymRef) {
- if (!Op.isExpr())
+ if (!SymOp.isExpr())
return typeError(ErrorLoc, StringRef("expected expression operand"));
- SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
+ SymRef = dyn_cast<MCSymbolRefExpr>(SymOp.getExpr());
if (!SymRef)
return typeError(ErrorLoc, StringRef("expected symbol operand"));
return false;
}
-bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCOperand &Op,
+bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc,
+ const MCOperand &GlobalOp,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Op, SymRef))
+ if (getSymRef(ErrorLoc, GlobalOp, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
@@ -216,10 +217,10 @@ bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCOperand &Op,
return false;
}
-bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCOperand &Op,
+bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCOperand &TableOp,
wasm::ValType &Type) {
const MCSymbolRefExpr *SymRef;
- if (getSymRef(ErrorLoc, Op, SymRef))
+ if (getSymRef(ErrorLoc, TableOp, SymRef))
return true;
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
@@ -230,25 +231,12 @@ bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCOperand &Op,
return false;
}
-bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
- // Check the return types.
- for (auto RVT : llvm::reverse(ReturnTypes)) {
- if (popType(ErrorLoc, RVT))
- return true;
- }
- if (!Stack.empty()) {
- return typeError(ErrorLoc, std::to_string(Stack.size()) +
- " superfluous return values");
- }
- Unreachable = true;
- return false;
-}
-
-bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc, const MCOperand &Op,
+bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
+ const MCOperand &SigOp,
wasm::WasmSymbolType Type,
const wasm::WasmSignature *&Sig) {
const MCSymbolRefExpr *SymRef = nullptr;
- if (getSymRef(ErrorLoc, Op, SymRef))
+ if (getSymRef(ErrorLoc, SigOp, SymRef))
return true;
const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
Sig = WasmSym->getSignature();
@@ -271,6 +259,20 @@ bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc, const MCOperand &Op,
return false;
}
+bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
+ // Check the return types.
+ for (auto RVT : llvm::reverse(ReturnTypes)) {
+ if (popType(ErrorLoc, RVT))
+ return true;
+ }
+ if (!Stack.empty()) {
+ return typeError(ErrorLoc, std::to_string(Stack.size()) +
+ " superfluous return values");
+ }
+ Unreachable = true;
+ return false;
+}
+
bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
OperandVector &Operands) {
auto Opc = Inst.getOpcode();
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 38638cea22b4ba..9ba5693719e91a 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -41,15 +41,16 @@ class WebAssemblyAsmTypeCheck final {
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
bool popType(SMLoc ErrorLoc, std::optional<wasm::ValType> EVT);
bool popRefType(SMLoc ErrorLoc);
- bool getLocal(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
+ bool getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp, wasm::ValType &Type);
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
bool checkBr(SMLoc ErrorLoc, size_t Level);
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
- bool getSymRef(SMLoc ErrorLoc, const MCOperand &Op,
+ bool getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
const MCSymbolRefExpr *&SymRef);
- bool getGlobal(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
- bool getTable(SMLoc ErrorLoc, const MCOperand &Op, wasm::ValType &Type);
- bool getSignature(SMLoc ErrorLoc, const MCOperand &Op,
+ bool getGlobal(SMLoc ErrorLoc, const MCOperand &GlobalOp,
+ wasm::ValType &Type);
+ bool getTable(SMLoc ErrorLoc, const MCOperand &TableOp, wasm::ValType &Type);
+ bool getSignature(SMLoc ErrorLoc, const MCOperand &SigOp,
wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
public:
More information about the llvm-commits mailing list