[Mlir-commits] [mlir] [MLIR][WASM] - Introduce an importer for Wasm binaries (PR #152131)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 5 05:24:32 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h mlir/include/mlir/Target/Wasm/WasmImporter.h mlir/lib/Target/Wasm/TranslateFromWasm.cpp mlir/lib/Target/Wasm/TranslateRegistration.cpp mlir/include/mlir/InitAllTranslations.h
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h
index 188c962db..5cc42a1f3 100644
--- a/mlir/include/mlir/Target/Wasm/WasmImporter.h
+++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h
@@ -27,7 +27,8 @@ namespace mlir::wasm {
/// arguments are declared at the beginning of the function.
/// If parameter 'fileId' is non-empty, then body of `emitc.file` ops
/// with matching id are emitted.
-OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext* context);
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+ MLIRContext *context);
} // namespace mlir::wasm
#endif // MLIR_TARGET_WASM_WASMIMPORTER_H
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 7fe59a49b..dd8b86670 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -29,7 +29,8 @@ STATISTIC(numGlobalSectionItems, "Parsed globals");
STATISTIC(numMemorySectionItems, "Parsed memories");
STATISTIC(numTableSectionItems, "Parsed tables");
-static_assert(CHAR_BIT == 8, "This code expects std::byte to be exactly 8 bits");
+static_assert(CHAR_BIT == 8,
+ "This code expects std::byte to be exactly 8 bits");
using namespace mlir;
using namespace mlir::wasm;
@@ -54,7 +55,7 @@ enum struct WasmSectionType : section_id_t {
};
constexpr section_id_t highestWasmSectionID{
- static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
+ static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
#define APPLY_WASM_SEC_TRANSFORM \
WASM_SEC_TRANSFORM(CUSTOM) \
@@ -85,7 +86,7 @@ constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
}
template <std::byte... Bytes>
-struct ByteSequence{};
+struct ByteSequence {};
template <std::byte... Bytes1, std::byte... Bytes2>
constexpr ByteSequence<Bytes1..., Bytes2...>
@@ -94,7 +95,7 @@ operator+(ByteSequence<Bytes1...>, ByteSequence<Bytes2...>) {
}
/// Template class for representing a byte sequence of only one byte
-template<std::byte Byte>
+template <std::byte Byte>
struct UniqueByte : ByteSequence<Byte> {};
template <typename T, T... Values>
@@ -112,12 +113,13 @@ constexpr ByteSequence<
WasmBinaryEncoding::Type::v128>
valueTypesEncodings{};
-template<std::byte... allowedFlags>
-constexpr bool isValueOneOf(std::byte value, ByteSequence<allowedFlags...> = {}) {
- return ((value == allowedFlags) | ... | false);
+template <std::byte... allowedFlags>
+constexpr bool isValueOneOf(std::byte value,
+ ByteSequence<allowedFlags...> = {}) {
+ return ((value == allowedFlags) | ... | false);
}
-template<std::byte... flags>
+template <std::byte... flags>
constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
return !isValueOneOf<flags...>(value);
}
@@ -143,7 +145,8 @@ struct FunctionSymbolRefContainer : SymbolRefContainer {
FunctionType functionType;
};
-using ImportDesc = std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+using ImportDesc =
+ std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
using parsed_inst_t = FailureOr<SmallVector<Value>>;
@@ -190,6 +193,7 @@ private:
size_t stackIdx;
LabelLevelOpInterface levelOp;
};
+
public:
bool empty() const { return values.empty(); }
@@ -204,7 +208,7 @@ public:
/// if an error occurs.
/// @return Failure or the vector of popped values.
FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
- Location *opLoc);
+ Location *opLoc);
/// Push the results of an operation to the stack so they can be used in a
/// following operation.
@@ -214,7 +218,6 @@ public:
/// if an error occurs.
LogicalResult pushResults(ValueRange results, Location *opLoc);
-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// A simple dump function for debugging.
/// Writes output to llvm::dbgs().
@@ -243,7 +246,6 @@ private:
parseConstInst(OpBuilder &builder,
std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
-
/// This function generates a dispatch tree to associate an opcode with a
/// parser. Parsers are registered by specialising the
/// `parseSpecificInstruction` function for the op code to handle.
@@ -266,8 +268,9 @@ private:
constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
constexpr size_t nextPatternBitSize = patternBitSize + 1;
if ((opCode & bitSelect) != std::byte{0})
- return dispatchToInstParser < nextPatternBitSize,
- nextHighBitPatternStem | std::byte{1} > (opCode, builder);
+ return dispatchToInstParser<nextPatternBitSize,
+ nextHighBitPatternStem | std::byte{1}>(
+ opCode, builder);
return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
opCode, builder);
} else {
@@ -281,35 +284,35 @@ private:
};
public:
- template<std::byte ParseEndByte = WasmBinaryEncoding::endByte>
- parsed_inst_t parse(OpBuilder &builder,
- UniqueByte<ParseEndByte> = {});
+ template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
+ parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
template <std::byte... ExpressionParseEnd>
FailureOr<ParseResultWithInfo>
parse(OpBuilder &builder,
ByteSequence<ExpressionParseEnd...> parsingEndFilters);
- FailureOr<SmallVector<Value>>
- popOperands(TypeRange operandTypes) {
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
return valueStack.popOperands(operandTypes, &currentOpLoc.value());
}
LogicalResult pushResults(ValueRange results) {
return valueStack.pushResults(results, &currentOpLoc.value());
}
+
private:
std::optional<Location> currentOpLoc;
ParserHead &parser;
WasmModuleSymbolTables const &symbols;
locals_t locals;
ValueStack valueStack;
- };
+};
class ParserHead {
public:
ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
ParserHead(ParserHead &&) = default;
+
private:
ParserHead(ParserHead const &other) = default;
@@ -458,8 +461,7 @@ public:
return failure();
if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
return emitError(typeLoc, "invalid function type header byte. Expecting ")
- << std::to_integer<unsigned>(
- WasmBinaryEncoding::Type::funcType)
+ << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
<< " got " << std::to_integer<unsigned>(*funcTypeHeader);
FailureOr<TupleType> inputTypes = parseResultType(ctx);
if (failed(inputTypes))
@@ -525,9 +527,7 @@ public:
bool end() const { return curHead().empty(); }
- ParserHead copy() const {
- return *this;
- }
+ ParserHead copy() const { return *this; }
private:
StringRef curHead() const { return head.drop_front(offset); }
@@ -575,7 +575,7 @@ FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
unsigned encodingSize{0};
StringRef src = curHead();
uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
- src.bytes_end(), &error);
+ src.bytes_end(), &error);
if (error)
return emitError(getLocation(), error);
@@ -594,7 +594,7 @@ FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
unsigned encodingSize{0};
StringRef src = curHead();
int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
- src.bytes_end(), &error);
+ src.bytes_end(), &error);
if (error)
return emitError(getLocation(), error);
if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
@@ -612,7 +612,7 @@ FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
unsigned encodingSize{0};
StringRef src = curHead();
int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
- src.bytes_end(), &error);
+ src.bytes_end(), &error);
if (error)
return emitError(getLocation(), error);
@@ -648,7 +648,7 @@ void ValueStack::dump() const {
// end of the vector. Iterate in reverse so that the first thing we print
// is the top of the stack.
size_t stackSize = size();
- for (size_t idx = 0 ; idx < stackSize ;) {
+ for (size_t idx = 0; idx < stackSize;) {
size_t actualIdx = stackSize - 1 - idx;
llvm::dbgs() << " ";
values[actualIdx].dump();
@@ -676,8 +676,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
Value operand = values[i + stackIdxOffset];
Type stackType = operand.getType();
if (stackType != operandTypes[i])
- return emitError(*opLoc,
- "invalid operand type on stack. Expecting ")
+ return emitError(*opLoc, "invalid operand type on stack. Expecting ")
<< operandTypes[i] << ", value on stack is of type " << stackType
<< ".";
LLVM_DEBUG(llvm::dbgs() << " POP: " << operand << "\n");
@@ -705,8 +704,9 @@ LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
return success();
}
-template<std::byte EndParseByte>
-parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte<EndParseByte> endByte) {
+template <std::byte EndParseByte>
+parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
+ UniqueByte<EndParseByte> endByte) {
auto res = parse(builder, ByteSequence<EndParseByte>{});
if (failed(res))
return failure();
@@ -735,7 +735,6 @@ ExpressionParser::parse(OpBuilder &builder,
}
}
-
template <typename T>
inline Type buildLiteralType(OpBuilder &);
@@ -769,7 +768,8 @@ inline Type buildLiteralType<double>(OpBuilder &builder) {
return builder.getF64Type();
}
-template<typename ValT, typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+template <typename ValT,
+ typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
struct AttrHolder;
template <typename ValT>
@@ -782,7 +782,7 @@ struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
using type = FloatAttr;
};
-template<typename ValT>
+template <typename ValT>
using attr_holder_t = typename AttrHolder<ValT>::type;
template <typename ValT,
@@ -826,13 +826,13 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
return parseConstInst<double>(builder);
}
-
class WasmBinaryParser {
private:
struct SectionRegistry {
using section_location_t = StringRef;
- std::array<SmallVector<section_location_t>, highestWasmSectionID+1> registry;
+ std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
+ registry;
template <WasmSectionType SecType>
std::conditional_t<sectionShouldBeUnique(SecType),
@@ -890,7 +890,6 @@ private:
if (failed(registration))
return failure();
-
}
return success();
}
@@ -947,8 +946,8 @@ private:
<< " type registration.";
FunctionType type = symbols.moduleFuncTypes[tid.id];
std::string symbol = symbols.getNewFuncSymbolName();
- auto funcOp = builder.create<FuncImportOp>(
- loc, symbol, moduleName, importName, type);
+ auto funcOp =
+ builder.create<FuncImportOp>(loc, symbol, moduleName, importName, type);
symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
return funcOp.verify();
}
@@ -975,13 +974,13 @@ private:
/// Handles the registration of a global variable import
LogicalResult visitImport(Location loc, StringRef moduleName,
- StringRef importName,
- GlobalTypeRecord globalType) {
+ StringRef importName, GlobalTypeRecord globalType) {
std::string symbol = symbols.getNewGlobalSymbolName();
auto giOp =
builder.create<GlobalImportOp>(loc, symbol, moduleName, importName,
globalType.type, globalType.isMutable);
- symbols.globalSymbols.push_back({{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
+ symbols.globalSymbols.push_back(
+ {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
return giOp.verify();
}
@@ -996,20 +995,20 @@ public:
uint32_t sourceBufId = sourceMgr.getMainFileID();
StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
srcName = StringAttr::get(
- ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
+ ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
auto parser = ParserHead{source, srcName};
auto const wasmHeader = StringRef{"\0asm", 4};
FileLineColLoc magicLoc = parser.getLocation();
FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
if (failed(magic) || magic->compare(wasmHeader)) {
- emitError(magicLoc,
- "source file does not contain valid Wasm header.");
+ emitError(magicLoc, "source file does not contain valid Wasm header.");
return;
}
auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
FileLineColLoc versionLoc = parser.getLocation();
- FailureOr<StringRef> version = parser.consumeNBytes(expectedVersionString.size());
+ FailureOr<StringRef> version =
+ parser.consumeNBytes(expectedVersionString.size());
if (failed(version))
return;
if (version->compare(expectedVersionString)) {
@@ -1022,8 +1021,7 @@ public:
return;
mOp = builder.create<ModuleOp>(getLocation());
- builder.setInsertionPointToStart(
- &mOp.getBodyRegion().front());
+ builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
if (failed(parsingTypes))
return;
@@ -1071,7 +1069,8 @@ private:
template <>
LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, size_t) {
+WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
+ size_t) {
FileLineColLoc importLoc = ph.getLocation();
auto moduleName = ph.parseName();
if (failed(moduleName))
@@ -1110,10 +1109,9 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
if (failed(idx))
return failure();
- using SymbolRefDesc =
- std::variant<SmallVector<SymbolRefContainer>,
- SmallVector<GlobalSymbolRefContainer>,
- SmallVector<FunctionSymbolRefContainer>>;
+ using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
+ SmallVector<GlobalSymbolRefContainer>,
+ SmallVector<FunctionSymbolRefContainer>>;
SymbolRefDesc currentSymbolList;
std::string symbolType = "";
@@ -1164,7 +1162,8 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
template <>
LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, size_t) {
+WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
+ size_t) {
FileLineColLoc opLocation = ph.getLocation();
FailureOr<TableType> tableType = ph.parseTableType(ctx);
if (failed(tableType))
@@ -1172,7 +1171,8 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, size_
LLVM_DEBUG(llvm::dbgs() << " Parsed table description: " << *tableType
<< '\n');
StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
- auto tableOp = builder.create<TableOp>(opLocation, symbol.strref(), *tableType);
+ auto tableOp =
+ builder.create<TableOp>(opLocation, symbol.strref(), *tableType);
symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
return success();
}
@@ -1216,7 +1216,8 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
template <>
LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph, size_t) {
+WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
+ size_t) {
FileLineColLoc opLocation = ph.getLocation();
FailureOr<LimitType> memory = ph.parseLimit(ctx);
if (failed(memory))
diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
index 9c0f7702a..03b97846d 100644
--- a/mlir/lib/Target/Wasm/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
@@ -11,18 +11,18 @@
#include "mlir/Target/Wasm/WasmImporter.h"
#include "mlir/Tools/mlir-translate/Translation.h"
-
using namespace mlir;
namespace mlir {
void registerFromWasmTranslation() {
TranslateToMLIRRegistration registration{
- "import-wasm", "Translate WASM to MLIR",
- [](llvm::SourceMgr &sourceMgr, MLIRContext* context) -> OwningOpRef<Operation*> {
- return wasm::importWebAssemblyToModule(sourceMgr, context);
- }, [](DialectRegistry& registry) {
- registry.insert<wasmssa::WasmSSADialect>();
- }
- };
+ "import-wasm", "Translate WASM to MLIR",
+ [](llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) -> OwningOpRef<Operation *> {
+ return wasm::importWebAssemblyToModule(sourceMgr, context);
+ },
+ [](DialectRegistry &registry) {
+ registry.insert<wasmssa::WasmSSADialect>();
+ }};
}
} // namespace mlir
``````````
</details>
https://github.com/llvm/llvm-project/pull/152131
More information about the Mlir-commits
mailing list