[Mlir-commits] [mlir] [MLIR][WasmSSA] Instruction parser refactoring of WasmSSA importer (PR #195500)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat May 2 22:33:36 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luc Forget (lforg37)
<details>
<summary>Changes</summary>
Refactored the WasmSSA importer's mechanism for dispatching control flow to the relevant instruction parser based on the opcode.
This prepares for instructions with multi-byte opcodes (e.g. vector instructions and some scalar instruction extensions), which will be able to reuse the same mechanism.
It also replaces the recursive bit-by-bit dispatch tree that located the parser for an opcode with a jump table indexed by the opcode.
---
Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/195500.diff
1 Files Affected:
- (modified) mlir/lib/Target/Wasm/TranslateFromWasm.cpp (+225-130)
``````````diff
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 048e964037558..e90d627bae151 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -237,6 +237,15 @@ class ValueStack {
using local_val_t = TypedValue<wasmssa::LocalRefType>;
+template <typename TargetType, size_t... IS>
+constexpr std::integer_sequence<TargetType, TargetType{IS}...>
+castIndexSequence(std::index_sequence<IS...>) {
+ return {};
+}
+
+constexpr auto all8bitsBytes =
+ castIndexSequence<std::byte>(std::make_index_sequence<256>());
+
class ExpressionParser {
public:
using locals_t = SmallVector<local_val_t>;
@@ -245,9 +254,6 @@ class ExpressionParser {
: parser{parser}, symbols{symbols}, locals{initLocal} {}
private:
- template <std::byte opCode>
- inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
-
template <typename valueT>
parsed_inst_t
parseConstInst(OpBuilder &builder,
@@ -280,38 +286,8 @@ class ExpressionParser {
typename... extraArgsT>
inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...);
- /// This function generates a dispatch tree to associate an opcode with a
- /// parser. Parsers are registered by specialising the
- /// `parseSpecificInstruction` function for the op code to handle.
- ///
- /// The dispatcher is generated by recursively creating all possible patterns
- /// for an opcode and calling the relevant parser on the leaf.
- ///
- /// @tparam patternBitSize is the first bit for which the pattern is not fixed
- ///
- /// @tparam highBitPattern is the fixed pattern that this instance handles for
- /// the 8-patternBitSize bits
- template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
inline parsed_inst_t dispatchToInstParser(std::byte opCode,
- OpBuilder &builder) {
- static_assert(patternBitSize <= 8,
- "PatternBitSize is outside of range of opcode space! "
- "(expected at most 8 bits)");
- if constexpr (patternBitSize < 8) {
- constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
- constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
- constexpr size_t nextPatternBitSize = patternBitSize + 1;
- if ((opCode & bitSelect) != std::byte{0})
- return dispatchToInstParser<nextPatternBitSize,
- nextHighBitPatternStem | std::byte{1}>(
- opCode, builder);
- return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
- opCode, builder);
- } else {
- return parseSpecificInstruction<highBitPattern>(builder);
- }
- }
-
+ OpBuilder &builder);
///
/// RAII guard class for creating a nesting level
///
@@ -444,6 +420,15 @@ class ExpressionParser {
template <typename OpToCreate>
parsed_inst_t parseBlockLikeOp(OpBuilder &);
+ std::optional<Location> getCurrentOpLoc() { return currentOpLoc; }
+
+ class TopLevelInstParserRegistry {
+ public:
+ template <std::byte opCode>
+ static parsed_inst_t parseInstrWithOpCode(OpBuilder &,
+ ExpressionParser &) = delete;
+ };
+
private:
std::optional<Location> currentOpLoc;
ParserHead &parser;
@@ -452,6 +437,79 @@ class ExpressionParser {
ValueStack valueStack;
};
+static inline parsed_inst_t
+unreachableHandler(OpBuilder &, ExpressionParser &expressionParser) {
+ llvm_unreachable("Failure in opcode parser dispatch logic.");
+ return mlir::failure();
+}
+
+template <typename ParserRegistry>
+class InstDispatcher {
+private:
+ /// used to check if a parser is registered in the parser registry for a
+ /// given opcode.
+ /// Default to unreachable handler.
+ template <std::byte opCode, typename = void>
+ struct ParserInfo : std::false_type {
+ static constexpr auto handler = unreachableHandler;
+ };
+ template <std::byte opCode>
+ struct ParserInfo<
+ opCode,
+ std::void_t<
+ decltype(&ParserRegistry::template parseInstrWithOpCode<opCode>)>>
+ : std::true_type {
+ static constexpr auto handler =
+ ParserRegistry::template parseInstrWithOpCode<opCode>;
+ };
+
+public:
+ template <std::byte opCode>
+ static constexpr bool isValidInst = ParserInfo<opCode>::value;
+
+private:
+ static inline parsed_inst_t
+ invalidOpcodeDiag(OpBuilder &, ExpressionParser &expressionParser,
+ std::byte opCode) {
+ return emitError(*(expressionParser.getCurrentOpLoc()),
+ "unknown instruction opcode: ")
+ << static_cast<int>(opCode);
+ }
+
+ using dispatch_t = parsed_inst_t (*)(OpBuilder &, ExpressionParser &);
+
+ template <std::byte... opCodes>
+ static inline parsed_inst_t
+ dispatchImpl(std::byte opCode, OpBuilder &builder,
+ ExpressionParser &exprParser,
+ std::integer_sequence<std::byte, opCodes...>) {
+ static constexpr std::array<bool, 256> opcodeValidityMap{
+ isValidInst<opCodes>...};
+ static constexpr std::array<dispatch_t, 256> dispatchTable{
+ ParserInfo<opCodes>::handler...};
+ if (opcodeValidityMap[static_cast<size_t>(opCode)]) {
+ return dispatchTable[static_cast<size_t>(opCode)](builder, exprParser);
+ }
+ return invalidOpcodeDiag(builder, exprParser, opCode);
+ }
+
+public:
+ ///
+ /// @brief dispatch control flow to the sub parser registered for opCode in
+ /// `ParserRegistry`
+ ///
+ /// @param opCode opCode of the instruction to be Parsed
+ /// @param builder builder that will be passed to the parser
+ /// @param exprParser the generic parser passed to the sub parser
+ ///
+ /// @return the result of the parser or an error if there is no parser
+ /// registered for the opcode (emits a diagnostic)
+ static parsed_inst_t dispatch(std::byte opCode, OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return dispatchImpl(opCode, builder, exprParser, all8bitsBytes);
+ }
+};
+
class ParserHead {
public:
ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
@@ -850,12 +908,6 @@ inline FailureOr<int64_t> ParserHead::parseI64() {
return parseLiteral<int64_t>();
}
-template <std::byte opCode>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
- return emitError(*currentOpLoc, "unknown instruction opcode: ")
- << static_cast<int>(opCode);
-}
-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void ValueStack::dump() const {
llvm::dbgs() << "================= Wasm ValueStack =======================\n";
@@ -1010,32 +1062,36 @@ parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
template <>
inline parsed_inst_t
-ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>(
- OpBuilder &builder) {
- return parseBlockLikeOp<BlockOp>(builder);
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::block>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseBlockLikeOp<BlockOp>(builder);
}
template <>
inline parsed_inst_t
-ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>(
- OpBuilder &builder) {
- return parseBlockLikeOp<LoopOp>(builder);
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::loop>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseBlockLikeOp<LoopOp>(builder);
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder) {
- auto opLoc = currentOpLoc;
- auto funcType = parseBlockFuncType(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ auto opLoc = exprParser.currentOpLoc;
+ auto funcType = exprParser.parseBlockFuncType(builder);
if (failed(funcType))
return failure();
LDBG() << "Parsing an if instruction of type " << *funcType;
auto inputTypes = funcType->getInputs();
- auto conditionValue = popOperands(builder.getI32Type());
+ auto conditionValue = exprParser.popOperands(builder.getI32Type());
if (failed(conditionValue))
return failure();
- auto inputOps = popOperands(inputTypes);
+ auto inputOps = exprParser.popOperands(inputTypes);
if (failed(inputOps))
return failure();
@@ -1043,25 +1099,25 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
Region *curRegion = curBlock->getParent();
auto resTypes = funcType->getResults();
llvm::SmallVector<Location> locations{};
- locations.resize(resTypes.size(), *currentOpLoc);
+ locations.resize(resTypes.size(), *exprParser.getCurrentOpLoc());
auto *successor =
builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
builder.setInsertionPointToEnd(curBlock);
- auto ifOp = IfOp::create(builder, *currentOpLoc, conditionValue->front(),
- *inputOps, successor);
+ auto ifOp = IfOp::create(builder, *exprParser.getCurrentOpLoc(),
+ conditionValue->front(), *inputOps, successor);
auto *ifEntryBlock = ifOp.createIfBlock();
constexpr auto ifElseFilter =
ByteSequence<WasmBinaryEncoding::endByte,
WasmBinaryEncoding::OpCode::elseOpCode>{};
- auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc,
- ifOp, ifElseFilter);
+ auto parseIfRes = exprParser.parseBlockContent(
+ builder, ifEntryBlock, resTypes, *opLoc, ifOp, ifElseFilter);
if (failed(parseIfRes))
return failure();
if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) {
LDBG() << " else block is present.";
Block *elseEntryBlock = ifOp.createElseBlock();
- auto parseElseRes =
- parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp);
+ auto parseElseRes = exprParser.parseBlockContent(builder, elseEntryBlock,
+ resTypes, *opLoc, ifOp);
if (failed(parseElseRes))
return failure();
}
@@ -1070,16 +1126,18 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder) {
- auto level = parser.parseLiteral<uint32_t>();
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ auto level = exprParser.parser.parseLiteral<uint32_t>();
if (failed(level))
return failure();
Block *curBlock = builder.getBlock();
Region *curRegion = curBlock->getParent();
auto sip = builder.saveInsertionPoint();
Block *elseBlock = builder.createBlock(curRegion, curRegion->end());
- auto condition = popOperands(builder.getI32Type());
+ auto condition = exprParser.popOperands(builder.getI32Type());
if (failed(condition))
return failure();
builder.restoreInsertionPoint(sip);
@@ -1088,10 +1146,10 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
if (failed(targetOp))
return failure();
auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes();
- auto branchArgs = popOperands(inputTypes);
+ auto branchArgs = exprParser.popOperands(inputTypes);
if (failed(branchArgs))
return failure();
- BranchIfOp::create(builder, *currentOpLoc, condition->front(),
+ BranchIfOp::create(builder, *exprParser.getCurrentOpLoc(), condition->front(),
builder.getUI32IntegerAttr(*level), *branchArgs,
elseBlock);
builder.setInsertionPointToStart(elseBlock);
@@ -1100,18 +1158,19 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
template <>
inline parsed_inst_t
-ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
- OpBuilder &builder) {
- auto loc = *currentOpLoc;
- auto funcIdx = parser.parseLiteral<uint32_t>();
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::call>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ auto loc = *exprParser.currentOpLoc;
+ auto funcIdx = exprParser.parser.parseLiteral<uint32_t>();
if (failed(funcIdx))
return failure();
- if (*funcIdx >= symbols.funcSymbols.size())
+ if (*funcIdx >= exprParser.symbols.funcSymbols.size())
return emitError(loc, "Invalid function index: ") << *funcIdx;
- auto callee = symbols.funcSymbols[*funcIdx];
+ auto callee = exprParser.symbols.funcSymbols[*funcIdx];
llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs();
llvm::ArrayRef<Type> resTypes = callee.functionType.getResults();
- parsed_inst_t inOperands = popOperands(inTypes);
+ parsed_inst_t inOperands = exprParser.popOperands(inTypes);
if (failed(inOperands))
return failure();
auto callOp =
@@ -1120,30 +1179,36 @@ ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) {
- FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
- Location instLoc = *currentOpLoc;
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ FailureOr<uint32_t> id = exprParser.parser.parseLiteral<uint32_t>();
+ Location instLoc = *exprParser.currentOpLoc;
if (failed(id))
return failure();
- if (*id >= locals.size())
+ if (*id >= exprParser.locals.size())
return emitError(instLoc, "invalid local index. function has ")
- << locals.size() << " accessible locals, received index " << *id;
- return {{LocalGetOp::create(builder, instLoc, locals[*id]).getResult()}};
+ << exprParser.locals.size() << " accessible locals, received index "
+ << *id;
+ return {{LocalGetOp::create(builder, instLoc, exprParser.locals[*id])
+ .getResult()}};
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::globalGet>(OpBuilder &builder) {
- FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
- Location instLoc = *currentOpLoc;
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::globalGet>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ FailureOr<uint32_t> id = exprParser.parser.parseLiteral<uint32_t>();
+ Location instLoc = *exprParser.currentOpLoc;
if (failed(id))
return failure();
- if (*id >= symbols.globalSymbols.size())
+ if (*id >= exprParser.symbols.globalSymbols.size())
return emitError(instLoc, "invalid global index. function has ")
- << symbols.globalSymbols.size()
+ << exprParser.symbols.globalSymbols.size()
<< " accessible globals, received index " << *id;
- GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id];
+ GlobalSymbolRefContainer globalVar = exprParser.symbols.globalSymbols[*id];
auto globalOp = GlobalGetOp::create(builder, instLoc, globalVar.globalType,
globalVar.symbol);
@@ -1172,15 +1237,19 @@ parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::localSet>(OpBuilder &builder) {
- return parseSetOrTee<LocalSetOp>(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::localSet>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseSetOrTee<LocalSetOp>(builder);
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::localTee>(OpBuilder &builder) {
- return parseSetOrTee<LocalTeeOp>(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::localTee>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseSetOrTee<LocalTeeOp>(builder);
}
template <typename T>
@@ -1252,27 +1321,35 @@ parsed_inst_t ExpressionParser::parseConstInst(
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
- return parseConstInst<int32_t>(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseConstInst<int32_t>(builder);
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
- return parseConstInst<int64_t>(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseConstInst<int64_t>(builder);
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
- return parseConstInst<float>(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseConstInst<float>(builder);
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
- return parseConstInst<double>(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+ WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder,
+ ExpressionParser &exprParser) {
+ return exprParser.parseConstInst<double>(builder);
}
template <typename opcode, typename valueType, unsigned int numOperands>
@@ -1295,9 +1372,10 @@ inline parsed_inst_t ExpressionParser::buildNumericOp(
// Convenience macro for generating numerical operations.
#define BUILD_NUMERIC_OP(OP_NAME, N_ARGS, PREFIX, SUFFIX, TYPE) \
template <> \
- inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
- WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>(OpBuilder & builder) { \
- return buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder); \
+ inline parsed_inst_t ExpressionParser::TopLevelInstParserRegistry:: \
+ parseInstrWithOpCode<WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>( \
+ OpBuilder & builder, ExpressionParser & exprParser) { \
+ return exprParser.buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder); \
}
// Macro to define binops that only support integer types.
@@ -1398,23 +1476,27 @@ inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
}
template <>
-inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
- WasmBinaryEncoding::OpCode::demoteF64ToF32>(OpBuilder &builder) {
- return buildConvertOp<DemoteOp, double, float>(builder);
+inline parsed_inst_t
+ExpressionParser::TopLevelInstParserRegistry::parseInstrWithOpCode<
+...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/195500
More information about the Mlir-commits
mailing list