[Mlir-commits] [mlir] [MLIR][WASM] Introduce an importer for Wasm binaries (PR #152131)
Luc Forget
llvmlistbot at llvm.org
Wed Aug 13 22:25:47 PDT 2025
https://github.com/lforg37 updated https://github.com/llvm/llvm-project/pull/152131
>From b6cc91e97b401de7724ac17dd0ca24f4e492398b Mon Sep 17 00:00:00 2001
From: Luc Forget <dev at alias.lforget.fr>
Date: Mon, 30 Jun 2025 15:46:25 +0200
Subject: [PATCH 01/14] [mlir][wasm] Adding wasm import target to
mlir-translate
This commit contains basic parsing infrastructure + base code to parse
wasm binary file type section.
---------
Co-authored-by: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Co-authored-by: Jessica Paquette <jessica.paquette at woven-planet.global>
---
mlir/include/mlir/InitAllTranslations.h | 3 +-
.../mlir/Target/Wasm/WasmBinaryEncoding.h | 55 ++
mlir/include/mlir/Target/Wasm/WasmImporter.h | 35 +
mlir/lib/Target/CMakeLists.txt | 1 +
mlir/lib/Target/Wasm/CMakeLists.txt | 13 +
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 797 ++++++++++++++++++
.../lib/Target/Wasm/TranslateRegistration.cpp | 28 +
mlir/test/Target/Wasm/bad_wasm_version.yaml | 8 +
mlir/test/Target/Wasm/import.mlir | 19 +
mlir/test/Target/Wasm/inputs/import.yaml.wasm | 44 +
mlir/test/Target/Wasm/inputs/stats.yaml.wasm | 38 +
.../Wasm/invalid_function_type_index.yaml | 18 +
mlir/test/Target/Wasm/missing_header.yaml | 12 +
mlir/test/Target/Wasm/stats.mlir | 19 +
14 files changed, 1089 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
create mode 100644 mlir/include/mlir/Target/Wasm/WasmImporter.h
create mode 100644 mlir/lib/Target/Wasm/CMakeLists.txt
create mode 100644 mlir/lib/Target/Wasm/TranslateFromWasm.cpp
create mode 100644 mlir/lib/Target/Wasm/TranslateRegistration.cpp
create mode 100644 mlir/test/Target/Wasm/bad_wasm_version.yaml
create mode 100644 mlir/test/Target/Wasm/import.mlir
create mode 100644 mlir/test/Target/Wasm/inputs/import.yaml.wasm
create mode 100644 mlir/test/Target/Wasm/inputs/stats.yaml.wasm
create mode 100644 mlir/test/Target/Wasm/invalid_function_type_index.yaml
create mode 100644 mlir/test/Target/Wasm/missing_header.yaml
create mode 100644 mlir/test/Target/Wasm/stats.mlir
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index 1ab80fb27fa9a..cf8f108b88159 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -17,9 +17,9 @@
#include "mlir/Target/IRDLToCpp/TranslationRegistration.h"
namespace mlir {
-
void registerFromLLVMIRTranslation();
void registerFromSPIRVTranslation();
+void registerFromWasmTranslation();
void registerToCppTranslation();
void registerToLLVMIRTranslation();
void registerToSPIRVTranslation();
@@ -36,6 +36,7 @@ inline void registerAllTranslations() {
registerFromLLVMIRTranslation();
registerFromSPIRVTranslation();
registerIRDLToCppTranslation();
+ registerFromWasmTranslation();
registerToCppTranslation();
registerToLLVMIRTranslation();
registerToSPIRVTranslation();
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
new file mode 100644
index 0000000000000..e01193e47fdea
--- /dev/null
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -0,0 +1,55 @@
+//===- WasmBinaryEncoding.h - Byte encodings for Wasm binary format -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Define encodings for WebAssembly instructions, types, etc from the
+// WebAssembly binary format.
+//
+// Each encoding is defined in the WebAssembly binary specification.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_TARGET_WASMBINARYENCODING
+#define MLIR_TARGET_WASMBINARYENCODING
+
+#include <cstddef>
+namespace mlir {
+struct WasmBinaryEncoding {
+ /// Byte encodings of types in WASM binaries
+ struct Type {
+ static constexpr std::byte emptyBlockType{0x40};
+ static constexpr std::byte funcType{0x60};
+ static constexpr std::byte externRef{0x6F};
+ static constexpr std::byte funcRef{0x70};
+ static constexpr std::byte v128{0x7B};
+ static constexpr std::byte f64{0x7C};
+ static constexpr std::byte f32{0x7D};
+ static constexpr std::byte i64{0x7E};
+ static constexpr std::byte i32{0x7F};
+ };
+
+ /// Byte encodings of WASM imports.
+ struct Import {
+ static constexpr std::byte typeID{0x00};
+ static constexpr std::byte tableType{0x01};
+ static constexpr std::byte memType{0x02};
+ static constexpr std::byte globalType{0x03};
+ };
+
+ /// Byte encodings for WASM limits.
+ struct LimitHeader {
+ static constexpr std::byte lowLimitOnly{0x00};
+ static constexpr std::byte bothLimits{0x01};
+ };
+
+ /// Byte encodings describing the mutability of globals.
+ struct GlobalMutability {
+ static constexpr std::byte isConst{0x00};
+ static constexpr std::byte isMutable{0x01};
+ };
+
+};
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h
new file mode 100644
index 0000000000000..fc7d275353964
--- /dev/null
+++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h
@@ -0,0 +1,35 @@
+//===- WasmImporter.h - Helpers to import WebAssembly binaries --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines helpers to import WebAssembly code using the WebAssembly
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_WASM_WASMIMPORTER_H
+#define MLIR_TARGET_WASM_WASMIMPORTER_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace wasm {
+
+/// Imports the WebAssembly binary held by `source` into an MLIR module.
+///
+/// `source` is expected to contain exactly one buffer holding a Wasm binary
+/// (magic number followed by binary format version 1). The operations of the
+/// resulting module are created in `context` using the WasmSSA dialect.
+///
+/// Returns a null reference when the source manager or the file header is
+/// rejected; problems are reported as diagnostics emitted on `context`.
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext* context);
+} // namespace wasm
+} // namespace mlir
+
+#endif // MLIR_TARGET_WASM_WASMIMPORTER_H
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 6eb0abc214d38..f0c3ac4d511c1 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -4,3 +4,4 @@ add_subdirectory(SPIRV)
add_subdirectory(LLVMIR)
add_subdirectory(LLVM)
add_subdirectory(SMTLIB)
+add_subdirectory(Wasm)
diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt
new file mode 100644
index 0000000000000..890fc0ecfbeb6
--- /dev/null
+++ b/mlir/lib/Target/Wasm/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_translation_library(MLIRTargetWasmImport
+ TranslateRegistration.cpp
+ TranslateFromWasm.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm
+
+ LINK_LIBS PUBLIC
+ MLIRWasmSSADialect
+ MLIRIR
+ MLIRSupport
+ MLIRTranslateLib
+)
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
new file mode 100644
index 0000000000000..2962cc212f848
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -0,0 +1,797 @@
+//===- TranslateFromWasm.cpp - Importing Wasm binaries to MLIR ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/Target/Wasm/WasmBinaryEncoding.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/LEB128.h"
+
+#include <variant>
+
+#define DEBUG_TYPE "wasm-translate"
+
+// Statistics.
+STATISTIC(numFunctionSectionItems, "Parsed functions");
+STATISTIC(numGlobalSectionItems, "Parsed globals");
+STATISTIC(numMemorySectionItems, "Parsed memories");
+STATISTIC(numTableSectionItems, "Parsed tables");
+
+static_assert(CHAR_BIT == 8, "This code expects std::byte to be exactly 8 bits");
+
+using namespace mlir;
+using namespace mlir::wasm;
+using namespace mlir::wasmssa;
+
+namespace {
+using section_id_t = uint8_t;
+enum struct WasmSectionType : section_id_t {
+ CUSTOM = 0,
+ TYPE = 1,
+ IMPORT = 2,
+ FUNCTION = 3,
+ TABLE = 4,
+ MEMORY = 5,
+ GLOBAL = 6,
+ EXPORT = 7,
+ START = 8,
+ ELEMENT = 9,
+ CODE = 10,
+ DATA = 11,
+ DATACOUNT = 12
+};
+
+constexpr section_id_t highestWasmSectionID{
+ static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
+
+#define APPLY_WASM_SEC_TRANSFORM \
+ WASM_SEC_TRANSFORM(CUSTOM) \
+ WASM_SEC_TRANSFORM(TYPE) \
+ WASM_SEC_TRANSFORM(IMPORT) \
+ WASM_SEC_TRANSFORM(FUNCTION) \
+ WASM_SEC_TRANSFORM(TABLE) \
+ WASM_SEC_TRANSFORM(MEMORY) \
+ WASM_SEC_TRANSFORM(GLOBAL) \
+ WASM_SEC_TRANSFORM(EXPORT) \
+ WASM_SEC_TRANSFORM(START) \
+ WASM_SEC_TRANSFORM(ELEMENT) \
+ WASM_SEC_TRANSFORM(CODE) \
+ WASM_SEC_TRANSFORM(DATA) \
+ WASM_SEC_TRANSFORM(DATACOUNT)
+
+template <WasmSectionType>
+constexpr const char *wasmSectionName = "";
+
+#define WASM_SEC_TRANSFORM(section) \
+ template <> \
+ constexpr const char *wasmSectionName<WasmSectionType::section> = #section;
+APPLY_WASM_SEC_TRANSFORM
+#undef WASM_SEC_TRANSFORM
+
+/// Tells whether at most one section of kind `secType` may be registered.
+/// Custom sections are the only ones allowed to be repeated.
+constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
+  const bool mayRepeat = secType == WasmSectionType::CUSTOM;
+  return !mayRepeat;
+}
+
+template <std::byte... Bytes>
+struct ByteSequence{};
+
+template <std::byte... Bytes1, std::byte... Bytes2>
+constexpr ByteSequence<Bytes1..., Bytes2...>
+operator+(ByteSequence<Bytes1...>, ByteSequence<Bytes2...>) {
+ return {};
+}
+
+/// Template class for representing a byte sequence of only one byte
+template<std::byte Byte>
+struct UniqueByte : ByteSequence<Byte> {};
+
+template <typename T, T... Values>
+constexpr ByteSequence<std::byte{Values}...>
+byteSeqFromIntSeq(std::integer_sequence<T, Values...>) {
+ return {};
+}
+
+constexpr auto allOpCodes =
+ byteSeqFromIntSeq(std::make_integer_sequence<int, 256>());
+
+constexpr ByteSequence<
+ WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64,
+ WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64,
+ WasmBinaryEncoding::Type::v128>
+ valueTypesEncodings{};
+
+/// Returns true when `value` equals at least one byte of the `allowedFlags`
+/// pack, and false for an empty pack. The unnamed `ByteSequence` parameter only
+/// serves to deduce the pack from an existing sequence object.
+template <std::byte... allowedFlags>
+constexpr bool isValueOneOf(std::byte value,
+                            ByteSequence<allowedFlags...> = {}) {
+  return (... || (value == allowedFlags));
+}
+
+/// Returns true when `value` matches none of the bytes of the `flags` pack
+/// (hence true for an empty pack).
+template <std::byte... flags>
+constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
+  const bool found = isValueOneOf<flags...>(value);
+  return !found;
+}
+
+struct GlobalTypeRecord {
+ Type type;
+ bool isMutable;
+};
+
+struct TypeIdxRecord {
+ size_t id;
+};
+
+struct SymbolRefContainer {
+ FlatSymbolRefAttr symbol;
+};
+
+struct GlobalSymbolRefContainer : SymbolRefContainer {
+ Type globalType;
+};
+
+struct FunctionSymbolRefContainer : SymbolRefContainer {
+ FunctionType functionType;
+};
+
+using ImportDesc = std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+
+/// Symbols and function types collected so far while importing a module.
+/// Each `getNew*SymbolName` helper derives a name from the number of symbols
+/// already stored in the matching table, so a name must be requested before
+/// the corresponding symbol is pushed.
+struct WasmModuleSymbolTables {
+  llvm::SmallVector<FunctionSymbolRefContainer> funcSymbols;
+  llvm::SmallVector<GlobalSymbolRefContainer> globalSymbols;
+  llvm::SmallVector<SymbolRefContainer> memSymbols;
+  llvm::SmallVector<SymbolRefContainer> tableSymbols;
+  llvm::SmallVector<FunctionType> moduleFuncTypes;
+
+  /// Builds the string `<prefix><id>`.
+  std::string getNewSymbolName(llvm::StringRef prefix, size_t id) const {
+    return (llvm::Twine(prefix) + llvm::Twine(id)).str();
+  }
+
+  /// Name for the next function symbol: `func_<number of functions>`.
+  std::string getNewFuncSymbolName() const {
+    return getNewSymbolName("func_", funcSymbols.size());
+  }
+
+  /// Name for the next global symbol: `global_<number of globals>`.
+  std::string getNewGlobalSymbolName() const {
+    return getNewSymbolName("global_", globalSymbols.size());
+  }
+
+  /// Name for the next memory symbol: `mem_<number of memories>`.
+  std::string getNewMemorySymbolName() const {
+    return getNewSymbolName("mem_", memSymbols.size());
+  }
+
+  /// Name for the next table symbol: `table_<number of tables>`.
+  std::string getNewTableSymbolName() const {
+    return getNewSymbolName("table_", tableSymbols.size());
+  }
+};
+class ParserHead {
+public:
+ ParserHead(llvm::StringRef src, StringAttr name) : head{src}, locName{name} {}
+ ParserHead(ParserHead &&) = default;
+private:
+ ParserHead(ParserHead const &other) = default;
+
+public:
+ auto getLocation() const {
+ return FileLineColLoc::get(locName, 0, anchorOffset + offset);
+ }
+
+ /// Consumes the next `nBytes` bytes of the input and returns them as a view
+ /// into the underlying buffer.
+ /// Emits an error and fails, leaving the parsing offset untouched, when fewer
+ /// than `nBytes` bytes remain.
+ llvm::FailureOr<llvm::StringRef> consumeNBytes(size_t nBytes) {
+   LLVM_DEBUG(llvm::dbgs() << "Consume " << nBytes << " bytes\n");
+   LLVM_DEBUG(llvm::dbgs() << " Bytes remaining: " << size() << "\n");
+   LLVM_DEBUG(llvm::dbgs() << " Current offset: " << offset << "\n");
+   if (nBytes > size())
+     return emitError(getLocation(), "trying to extract ")
+            << nBytes << " bytes when only " << size() << " are available";
+
+   auto res = head.slice(offset, offset + nBytes);
+   offset += nBytes;
+   LLVM_DEBUG(llvm::dbgs()
+              << " Updated offset (+" << nBytes << "): " << offset << "\n");
+   return res;
+ }
+
+ llvm::FailureOr<std::byte> consumeByte() {
+ auto res = consumeNBytes(1);
+ if (failed(res))
+ return failure();
+ return std::byte{*res->bytes_begin()};
+ }
+
+ template <typename T>
+ llvm::FailureOr<T> parseLiteral();
+
+ llvm::FailureOr<uint32_t> parseVectorSize();
+
+private:
+ // TODO: These are equivalent to parseLiteral<uint32_t> and
+ // parseLiteral<int64_t>, and could be removed if the parseLiteral
+ // specializations were moved here, but the default GCC on Ubuntu 22.04 has a
+ // bug with template specializations in class declarations.
+ inline llvm::FailureOr<uint32_t> parseUI32();
+ inline llvm::FailureOr<int64_t> parseI64();
+
+public:
+ llvm::FailureOr<llvm::StringRef> parseName() {
+ auto size = parseVectorSize();
+ if (failed(size))
+ return failure();
+
+ return consumeNBytes(*size);
+ }
+
+ llvm::FailureOr<WasmSectionType> parseWasmSectionType() {
+ auto id = consumeByte();
+ if (failed(id))
+ return failure();
+ if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
+ return emitError(getLocation(), "Invalid section ID: ")
+ << static_cast<int>(*id);
+ return static_cast<WasmSectionType>(*id);
+ }
+
+ llvm::FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
+ using WasmLimits = WasmBinaryEncoding::LimitHeader;
+ auto limitLocation = getLocation();
+ auto limitHeader = consumeByte();
+ if (failed(limitHeader))
+ return failure();
+
+ if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
+ return emitError(limitLocation, "Invalid limit header: ")
+ << static_cast<int>(*limitHeader);
+ auto minParse = parseUI32();
+ if (failed(minParse))
+ return failure();
+ std::optional<uint32_t> max{std::nullopt};
+ if (*limitHeader == WasmLimits::bothLimits) {
+ auto maxParse = parseUI32();
+ if (failed(maxParse))
+ return failure();
+ max = *maxParse;
+ }
+ return LimitType::get(ctx, *minParse, max);
+ }
+
+ llvm::FailureOr<Type> parseValueType(MLIRContext *ctx) {
+ auto typeLoc = getLocation();
+ auto typeEncoding = consumeByte();
+ if (failed(typeEncoding))
+ return failure();
+ switch (*typeEncoding) {
+ case WasmBinaryEncoding::Type::i32:
+ return IntegerType::get(ctx, 32);
+ case WasmBinaryEncoding::Type::i64:
+ return IntegerType::get(ctx, 64);
+ case WasmBinaryEncoding::Type::f32:
+ return Float32Type::get(ctx);
+ case WasmBinaryEncoding::Type::f64:
+ return Float64Type::get(ctx);
+ case WasmBinaryEncoding::Type::v128:
+ return IntegerType::get(ctx, 128);
+ case WasmBinaryEncoding::Type::funcRef:
+ return wasmssa::FuncRefType::get(ctx);
+ case WasmBinaryEncoding::Type::externRef:
+ return wasmssa::ExternRefType::get(ctx);
+ default:
+ return emitError(typeLoc, "Invalid value type encoding: ")
+ << static_cast<int>(*typeEncoding);
+ }
+ }
+
+ llvm::FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
+ using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
+ auto typeParsed = parseValueType(ctx);
+ if (failed(typeParsed))
+ return failure();
+ auto mutLoc = getLocation();
+ auto mutSpec = consumeByte();
+ if (failed(mutSpec))
+ return failure();
+ if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
+ return emitError(mutLoc, "Invalid global mutability specifier: ")
+ << static_cast<int>(*mutSpec);
+ return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
+ }
+
+ /// Parses a vector of value types (an element count followed by that many
+ /// value type encodings) and returns them packed in a `TupleType`.
+ /// Fails if the count or any of the value types cannot be parsed.
+ llvm::FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
+   auto nParamsParsed = parseVectorSize();
+   if (failed(nParamsParsed))
+     return failure();
+   size_t nParams = *nParamsParsed;
+   llvm::SmallVector<Type> res{};
+   // The element count is read straight from the binary and may be bogus.
+   // Every value type consumes one byte of input, so never reserve more
+   // slots than there are bytes left to parse.
+   size_t remaining = size();
+   res.reserve(nParams < remaining ? nParams : remaining);
+   for (size_t i = 0; i < nParams; ++i) {
+     auto parsedType = parseValueType(ctx);
+     if (failed(parsedType))
+       return failure();
+     res.push_back(*parsedType);
+   }
+   return TupleType::get(ctx, res);
+ }
+
+ llvm::FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
+ auto typeLoc = getLocation();
+ auto funcTypeHeader = consumeByte();
+ if (failed(funcTypeHeader))
+ return failure();
+ if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
+ return emitError(typeLoc, "Invalid function type header byte. Expecting ")
+ << std::to_integer<unsigned>(
+ WasmBinaryEncoding::Type::funcType)
+ << " got " << std::to_integer<unsigned>(*funcTypeHeader);
+ auto inputTypes = parseResultType(ctx);
+ if (failed(inputTypes))
+ return failure();
+
+ auto resTypes = parseResultType(ctx);
+ if (failed(resTypes))
+ return failure();
+
+ return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
+ }
+
+ llvm::FailureOr<TypeIdxRecord> parseTypeIndex() {
+ auto res = parseUI32();
+ if (failed(res))
+ return failure();
+ return TypeIdxRecord{*res};
+ }
+
+ /// Parses a table type: an element type, which must be a Wasm reference
+ /// type, followed by the limits of the table.
+ llvm::FailureOr<TableType> parseTableType(MLIRContext *ctx) {
+   // Remember where the element type starts so that a rejected type is
+   // reported at its own position rather than after it.
+   auto elmTypeLoc = getLocation();
+   auto elmTypeParse = parseValueType(ctx);
+   if (failed(elmTypeParse))
+     return failure();
+   if (!isWasmRefType(*elmTypeParse))
+     return emitError(elmTypeLoc, "Invalid element type for table");
+   auto limitParse = parseLimit(ctx);
+   if (failed(limitParse))
+     return failure();
+   return TableType::get(ctx, *elmTypeParse, *limitParse);
+ }
+
+ llvm::FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
+ auto importLoc = getLocation();
+ auto importType = consumeByte();
+ auto packager = [](auto parseResult) -> llvm::FailureOr<ImportDesc> {
+ if (llvm::failed(parseResult))
+ return failure();
+ return {*parseResult};
+ };
+ if (failed(importType))
+ return failure();
+ switch (*importType) {
+ case WasmBinaryEncoding::Import::typeID:
+ return packager(parseTypeIndex());
+ case WasmBinaryEncoding::Import::tableType:
+ return packager(parseTableType(ctx));
+ case WasmBinaryEncoding::Import::memType:
+ return packager(parseLimit(ctx));
+ case WasmBinaryEncoding::Import::globalType:
+ return packager(parseGlobalType(ctx));
+ default:
+ return emitError(importLoc, "Invalid import type descriptor: ")
+ << static_cast<int>(*importType);
+ }
+ }
+ bool end() const { return curHead().empty(); }
+
+ ParserHead copy() const {
+ return *this;
+ }
+
+private:
+ llvm::StringRef curHead() const { return head.drop_front(offset); }
+
+ llvm::FailureOr<std::byte> peek() const {
+ if (end())
+ return emitError(
+ getLocation(),
+ "trying to peek at next byte, but input stream is empty");
+ return static_cast<std::byte>(curHead().front());
+ }
+
+ size_t size() const { return head.size() - offset; }
+
+ llvm::StringRef head;
+ StringAttr locName;
+ unsigned anchorOffset{0};
+ unsigned offset{0};
+};
+
+template <>
+llvm::FailureOr<float> ParserHead::parseLiteral<float>() {
+ auto bytes = consumeNBytes(4);
+ if (failed(bytes))
+ return failure();
+ float result;
+ std::memcpy(&result, bytes->bytes_begin(), 4);
+ return result;
+}
+
+template <>
+llvm::FailureOr<double> ParserHead::parseLiteral<double>() {
+ auto bytes = consumeNBytes(8);
+ if (failed(bytes))
+ return failure();
+ double result;
+ std::memcpy(&result, bytes->bytes_begin(), 8);
+ return result;
+}
+
+/// Decodes an unsigned LEB128 value at the current position and advances the
+/// parsing offset past its encoding.
+/// Fails, without moving the offset, if the encoding is malformed or if the
+/// decoded value does not fit in 32 bits.
+template <>
+llvm::FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
+  char const *error = nullptr;
+  unsigned encodingSize{0};
+  auto src = curHead();
+  uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
+                                         src.bytes_end(), &error);
+  if (error)
+    return emitError(getLocation(), error);
+
+  // Plain integer comparison: both operands are unsigned integers.
+  if (decoded > std::numeric_limits<uint32_t>::max())
+    return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+  offset += encodingSize;
+  return static_cast<uint32_t>(decoded);
+}
+
+/// Decodes a signed LEB128 value at the current position and advances the
+/// parsing offset past its encoding.
+/// Fails, without moving the offset, if the encoding is malformed or if the
+/// decoded value is outside the range of a 32-bit signed integer.
+template <>
+llvm::FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
+  char const *error = nullptr;
+  unsigned encodingSize{0};
+  auto src = curHead();
+  int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+                                        src.bytes_end(), &error);
+  if (error)
+    return emitError(getLocation(), error);
+
+  // Plain integer comparisons: both bounds are promoted to int64_t.
+  if (decoded > std::numeric_limits<int32_t>::max() ||
+      decoded < std::numeric_limits<int32_t>::min())
+    return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+  offset += encodingSize;
+  return static_cast<int32_t>(decoded);
+}
+
+template <>
+llvm::FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
+ char const *error = nullptr;
+ unsigned encodingSize{0};
+ auto src = curHead();
+ auto res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+
+ offset += encodingSize;
+ return res;
+}
+
+llvm::FailureOr<uint32_t> ParserHead::parseVectorSize() {
+ return parseLiteral<uint32_t>();
+}
+
+inline llvm::FailureOr<uint32_t> ParserHead::parseUI32() {
+ return parseLiteral<uint32_t>();
+}
+
+inline llvm::FailureOr<int64_t> ParserHead::parseI64() {
+ return parseLiteral<int64_t>();
+}
+
+class WasmBinaryParser {
+private:
+ struct SectionRegistry {
+ using section_location_t = llvm::StringRef;
+
+ std::array<llvm::SmallVector<section_location_t>, highestWasmSectionID+1> registry;
+
+ template <WasmSectionType SecType>
+ std::conditional_t<sectionShouldBeUnique(SecType),
+ std::optional<section_location_t>,
+ llvm::ArrayRef<section_location_t>>
+ getContentForSection() const {
+ constexpr auto idx = static_cast<size_t>(SecType);
+ if constexpr (sectionShouldBeUnique(SecType)) {
+ return registry[idx].empty() ? std::nullopt
+ : std::make_optional(registry[idx][0]);
+ } else {
+ return registry[idx];
+ }
+ }
+
+ bool hasSection(WasmSectionType secType) const {
+ return !registry[static_cast<size_t>(secType)].empty();
+ }
+
+ ///
+ /// @returns success if the registration is valid, failure if it cannot be
+ /// done (i.e. another section of the same type already exists and this
+ /// section type may only be present once).
+ ///
+ LogicalResult registerSection(WasmSectionType secType,
+ section_location_t location, Location loc) {
+ if (sectionShouldBeUnique(secType) && hasSection(secType))
+ return emitError(loc,
+ "Trying to add a second instance of unique section");
+
+ registry[static_cast<size_t>(secType)].push_back(location);
+ emitRemark(loc, "Adding section with section ID ")
+ << static_cast<uint8_t>(secType);
+ return success();
+ }
+
+ /// Walks the input held by `ph` until it is exhausted and registers the
+ /// content of every section found. Each section is read as a section ID
+ /// byte, a LEB128-encoded 32-bit size, then that many bytes of content.
+ /// Fails if a section ID or size cannot be parsed, if the announced size
+ /// exceeds the remaining input, or if the registration is refused.
+ LogicalResult populateFromBody(ParserHead ph) {
+ while (!ph.end()) {
+ // Location of the section ID byte, used when reporting on this section.
+ auto sectionLoc = ph.getLocation();
+ auto secType = ph.parseWasmSectionType();
+ if (failed(secType))
+ return failure();
+
+ auto secSizeParsed = ph.parseLiteral<uint32_t>();
+ if (failed(secSizeParsed))
+ return failure();
+
+ auto secSize = *secSizeParsed;
+ // Only a view of the section bytes is stored; nothing is decoded here.
+ auto sectionContent = ph.consumeNBytes(secSize);
+ if (failed(sectionContent))
+ return failure();
+
+ auto registration =
+ registerSection(*secType, *sectionContent, sectionLoc);
+
+ if (failed(registration))
+ return failure();
+
+ }
+ return success();
+ }
+ };
+
+ auto getLocation(int offset = 0) const {
+ return FileLineColLoc::get(srcName, 0, offset);
+ }
+
+ template <WasmSectionType>
+ LogicalResult parseSectionItem(ParserHead &, size_t);
+
+ /// Parses every item of the section of kind `section`, if the file has one.
+ /// The section content is read as an item count followed by the items, each
+ /// handled by the matching `parseSectionItem` specialization.
+ /// Succeeds when the section is absent. Fails if the count or an item cannot
+ /// be parsed, or if bytes remain once all the announced items are parsed.
+ template <WasmSectionType section>
+ LogicalResult parseSection() {
+   auto secName = std::string{wasmSectionName<section>};
+   auto sectionNameAttr =
+       StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION");
+   auto secContent = registry.getContentForSection<section>();
+   if (!secContent) {
+     LLVM_DEBUG(llvm::dbgs()
+                << secName << " section is not present in file.\n");
+     return success();
+   }
+
+   auto secSrc = secContent.value();
+   ParserHead ph{secSrc, sectionNameAttr};
+   auto nElemsParsed = ph.parseVectorSize();
+   if (failed(nElemsParsed))
+     return failure();
+   auto nElems = *nElemsParsed;
+   LLVM_DEBUG(llvm::dbgs() << "Starting to parse " << nElems
+                           << " items for section " << secName << ".\n");
+   for (size_t i = 0; i < nElems; ++i) {
+     if (failed(parseSectionItem<section>(ph, i)))
+       return failure();
+   }
+
+   // Report the leftover bytes where they start, not at the section start.
+   if (!ph.end())
+     return emitError(ph.getLocation(), "Unparsed garbage at end of section ")
+            << secName;
+   return success();
+ }
+
+ /// Handles the registration of a function import
+ LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
+ llvm::StringRef importName, TypeIdxRecord tid) {
+ using llvm::Twine;
+ if (tid.id >= symbols.moduleFuncTypes.size())
+ return emitError(loc, "Invalid type id: ")
+ << tid.id << ". Only " << symbols.moduleFuncTypes.size()
+ << " type registration.";
+ auto type = symbols.moduleFuncTypes[tid.id];
+ auto symbol = symbols.getNewFuncSymbolName();
+ auto funcOp = builder.create<FuncImportOp>(
+ loc, symbol, moduleName, importName, type);
+ symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
+ return funcOp.verify();
+ }
+
+ /// Handles the registration of a memory import
+ LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
+ llvm::StringRef importName, LimitType limitType) {
+ auto symbol = symbols.getNewMemorySymbolName();
+ auto memOp = builder.create<MemImportOp>(loc, symbol, moduleName,
+ importName, limitType);
+ symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
+ return memOp.verify();
+ }
+
+ /// Handles the registration of a table import
+ LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
+ llvm::StringRef importName, TableType tableType) {
+ auto symbol = symbols.getNewTableSymbolName();
+ auto tableOp = builder.create<TableImportOp>(loc, symbol, moduleName,
+ importName, tableType);
+ symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
+ return tableOp.verify();
+ }
+
+ /// Handles the registration of a global variable import
+ LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
+ llvm::StringRef importName,
+ GlobalTypeRecord globalType) {
+ auto symbol = symbols.getNewGlobalSymbolName();
+ auto giOp =
+ builder.create<GlobalImportOp>(loc, symbol, moduleName, importName,
+ globalType.type, globalType.isMutable);
+ symbols.globalSymbols.push_back({{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
+ return giOp.verify();
+ }
+
+public:
+ /// Parses the Wasm binary held by `sourceMgr` and builds the matching
+ /// module, which can then be retrieved with `getModule()`.
+ ///
+ /// `sourceMgr` must hold exactly one buffer, starting with the Wasm magic
+ /// number followed by binary format version 1. On any failure an error
+ /// diagnostic is emitted and `getModule()` returns a null module: a module
+ /// that was only partially populated is erased rather than exposed.
+ WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
+     : builder{ctx}, ctx{ctx} {
+   ctx->loadAllAvailableDialects();
+   if (sourceMgr.getNumBuffers() != 1) {
+     emitError(UnknownLoc::get(ctx), "One source file should be provided");
+     return;
+   }
+   auto sourceBufId = sourceMgr.getMainFileID();
+   auto source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
+   srcName = StringAttr::get(
+       ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
+
+   auto parser = ParserHead{source, srcName};
+   auto const wasmHeader = StringRef{"\0asm", 4};
+   auto magicLoc = parser.getLocation();
+   auto magic = parser.consumeNBytes(wasmHeader.size());
+   if (failed(magic) || magic->compare(wasmHeader)) {
+     emitError(magicLoc,
+               "Source file does not contain valid Wasm header.");
+     return;
+   }
+   auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
+   auto versionLoc = parser.getLocation();
+   auto version = parser.consumeNBytes(expectedVersionString.size());
+   if (failed(version))
+     return;
+   if (version->compare(expectedVersionString)) {
+     emitError(versionLoc,
+               "Unsupported Wasm version. Only version 1 is supported.");
+     return;
+   }
+   auto fillRegistry = registry.populateFromBody(parser.copy());
+   if (failed(fillRegistry))
+     return;
+
+   mOp = builder.create<ModuleOp>(getLocation());
+   builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
+
+   // Types come first because imports and functions refer to them by index;
+   // imported functions are registered before the module's own functions.
+   auto parseAllSections = [this]() -> LogicalResult {
+     if (failed(parseSection<WasmSectionType::TYPE>()) ||
+         failed(parseSection<WasmSectionType::IMPORT>()))
+       return failure();
+
+     firstInternalFuncID = symbols.funcSymbols.size();
+
+     return parseSection<WasmSectionType::FUNCTION>();
+   };
+   if (failed(parseAllSections())) {
+     // Do not hand out a half-built module: drop it so that getModule()
+     // reports the failure to the caller.
+     mOp.erase();
+     mOp = ModuleOp{};
+     return;
+   }
+
+   // Copy over sizes of containers into statistics.
+   numFunctionSectionItems = symbols.funcSymbols.size();
+   numGlobalSectionItems = symbols.globalSymbols.size();
+   numMemorySectionItems = symbols.memSymbols.size();
+   numTableSectionItems = symbols.tableSymbols.size();
+ }
+
+ ModuleOp getModule() { return mOp; }
+
+private:
+ mlir::StringAttr srcName;
+ OpBuilder builder;
+ WasmModuleSymbolTables symbols;
+ MLIRContext *ctx;
+ ModuleOp mOp;
+ SectionRegistry registry;
+ size_t firstInternalFuncID{0};
+};
+
+/// Parses one entry of the IMPORT section: a module name, an import name and
+/// an import descriptor, then hands them to the `visitImport` overload that
+/// matches the kind of the descriptor. The item index parameter is unused.
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, size_t) {
+ // Every diagnostic and operation for this import is attached to its start.
+ auto importLoc = ph.getLocation();
+ auto moduleName = ph.parseName();
+ if (failed(moduleName))
+ return failure();
+
+ auto importName = ph.parseName();
+ if (failed(importName))
+ return failure();
+
+ auto import = ph.parseImportDesc(ctx);
+ if (failed(import))
+ return failure();
+
+ // Dispatch on the alternative currently held by the descriptor variant.
+ return std::visit(
+ [this, importLoc, &moduleName, &importName](auto import) {
+ return visitImport(importLoc, *moduleName, *importName, import);
+ },
+ *import);
+}
+
+/// Parses one entry of the FUNCTION section, i.e. the index of the type of a
+/// function defined by the module, and creates the matching `FuncOp`.
+/// The function body is not parsed here: the entry block only receives a
+/// `ReturnOp`. The new symbol is recorded in `symbols.funcSymbols`.
+/// Fails if the index cannot be parsed, if it does not refer to a registered
+/// function type, or if the created operation does not verify.
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
+                                                              size_t) {
+  auto opLoc = ph.getLocation();
+  auto typeIdxParsed = ph.parseLiteral<uint32_t>();
+  if (failed(typeIdxParsed))
+    return failure();
+  auto typeIdx = *typeIdxParsed;
+  // Report the bad index where it was read, not at the start of the file.
+  if (typeIdx >= symbols.moduleFuncTypes.size())
+    return emitError(opLoc, "Invalid type index: ") << typeIdx;
+  auto symbol = symbols.getNewFuncSymbolName();
+  auto funcOp =
+      builder.create<FuncOp>(opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
+  auto *block = funcOp.addEntryBlock();
+  // Fill the body, then go back to inserting operations at module level.
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPointToEnd(block);
+  builder.create<ReturnOp>(opLoc);
+  builder.restoreInsertionPoint(ip);
+  symbols.funcSymbols.push_back(
+      {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())},
+       symbols.moduleFuncTypes[typeIdx]});
+  return funcOp.verify();
+}
+
+/// Parses one entry of the TYPE section, a function type, and appends it to
+/// `symbols.moduleFuncTypes` so that later sections can refer to it by index.
+/// The item index parameter is unused.
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
+ size_t) {
+ auto funcType = ph.parseFunctionType(ctx);
+ if (failed(funcType))
+ return failure();
+ LLVM_DEBUG(llvm::dbgs() << "Parsed function type " << *funcType << '\n');
+ symbols.moduleFuncTypes.push_back(*funcType);
+ return success();
+}
+} // namespace
+
+namespace mlir {
+namespace wasm {
+/// Runs a `WasmBinaryParser` over `source` and transfers ownership of the
+/// module it produced to the caller, or returns a null reference if the
+/// parser did not produce any module.
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+                                                MLIRContext *context) {
+  WasmBinaryParser binaryParser{source, context};
+  if (ModuleOp importedModule = binaryParser.getModule())
+    return {importedModule};
+
+  return {nullptr};
+}
+} // namespace wasm
+} // namespace mlir
diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
new file mode 100644
index 0000000000000..9c0f7702a96aa
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
@@ -0,0 +1,28 @@
+//===- TranslateRegistration.cpp - Register translation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+
+
+using namespace mlir;
+
+namespace mlir {
+/// Registers the `import-wasm` translation with the translation registry.
+void registerFromWasmTranslation() {
+  // Translation entry point: forwards the source buffer to the Wasm importer.
+  auto importFn = [](llvm::SourceMgr &sourceMgr,
+                     MLIRContext *context) -> OwningOpRef<Operation *> {
+    return wasm::importWebAssemblyToModule(sourceMgr, context);
+  };
+  // Makes the WasmSSA dialect available to the translation.
+  auto registerDialects = [](DialectRegistry &registry) {
+    registry.insert<wasmssa::WasmSSADialect>();
+  };
+  TranslateToMLIRRegistration registration{
+      "import-wasm", "Translate WASM to MLIR", importFn, registerDialects};
+}
+} // namespace mlir
diff --git a/mlir/test/Target/Wasm/bad_wasm_version.yaml b/mlir/test/Target/Wasm/bad_wasm_version.yaml
new file mode 100644
index 0000000000000..4fed1d5a3af3c
--- /dev/null
+++ b/mlir/test/Target/Wasm/bad_wasm_version.yaml
@@ -0,0 +1,8 @@
+# RUN: yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s
+
+# CHECK: Unsupported Wasm version
+
+--- !WASM
+FileHeader:
+ Version: 0xDEADBEEF
+...
diff --git a/mlir/test/Target/Wasm/import.mlir b/mlir/test/Target/Wasm/import.mlir
new file mode 100644
index 0000000000000..541dcf3a2d9eb
--- /dev/null
+++ b/mlir/test/Target/Wasm/import.mlir
@@ -0,0 +1,19 @@
+// RUN: yaml2obj %S/inputs/import.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(import "my_module" "foo" (func $foo (param i32)))
+(import "my_module" "bar" (func $bar (param i32)))
+(import "my_module" "table" (table $round 2 funcref))
+(import "my_module" "mem" (memory $mymem 2))
+(import "my_module" "glob" (global $globglob i32))
+(import "my_other_module" "glob_mut" (global $glob_mut (mut i32)))
+)
+*/
+
+// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
+// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
+// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
+// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
+// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
+// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
diff --git a/mlir/test/Target/Wasm/inputs/import.yaml.wasm b/mlir/test/Target/Wasm/inputs/import.yaml.wasm
new file mode 100644
index 0000000000000..7c467ff6fbc67
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/import.yaml.wasm
@@ -0,0 +1,44 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes: []
+ - Type: IMPORT
+ Imports:
+ - Module: my_module
+ Field: foo
+ Kind: FUNCTION
+ SigIndex: 0
+ - Module: my_module
+ Field: bar
+ Kind: FUNCTION
+ SigIndex: 0
+ - Module: my_module
+ Field: table
+ Kind: TABLE
+ Table:
+ Index: 0
+ ElemType: FUNCREF
+ Limits:
+ Minimum: 0x2
+ - Module: my_module
+ Field: mem
+ Kind: MEMORY
+ Memory:
+ Minimum: 0x2
+ - Module: my_module
+ Field: glob
+ Kind: GLOBAL
+ GlobalType: I32
+ GlobalMutable: false
+ - Module: my_other_module
+ Field: glob_mut
+ Kind: GLOBAL
+ GlobalType: I32
+ GlobalMutable: true
+...
diff --git a/mlir/test/Target/Wasm/inputs/stats.yaml.wasm b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm
new file mode 100644
index 0000000000000..bf577688b3aed
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm
@@ -0,0 +1,38 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: TABLE
+ Tables:
+ - Index: 0
+ ElemType: FUNCREF
+ Limits:
+ Minimum: 0x2
+ - Type: MEMORY
+ Memories:
+ - Flags: [ HAS_MAX ]
+ Minimum: 0x0
+ Maximum: 0x10000
+ - Type: GLOBAL
+ Globals:
+ - Index: 0
+ Type: I32
+ Mutable: false
+ InitExpr:
+ Opcode: I32_CONST
+ Value: 10
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 20000B
+...
diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
new file mode 100644
index 0000000000000..961e9cc6e8029
--- /dev/null
+++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
@@ -0,0 +1,18 @@
+# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
+# CHECK: error: Invalid type index: 2
+
+# FIXME: mlir-translate should not return 0 here.
+
+--- !WASM
+FileHeader:
+ Version: 0x00000001
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes:
+ - 2
diff --git a/mlir/test/Target/Wasm/missing_header.yaml b/mlir/test/Target/Wasm/missing_header.yaml
new file mode 100644
index 0000000000000..5610f9c5c6e33
--- /dev/null
+++ b/mlir/test/Target/Wasm/missing_header.yaml
@@ -0,0 +1,12 @@
+# RUN: not yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s
+
+# CHECK: Source file does not contain valid Wasm header
+
+--- !WASM
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+...
diff --git a/mlir/test/Target/Wasm/stats.mlir b/mlir/test/Target/Wasm/stats.mlir
new file mode 100644
index 0000000000000..e68b85d20f67d
--- /dev/null
+++ b/mlir/test/Target/Wasm/stats.mlir
@@ -0,0 +1,19 @@
+// RUN: yaml2obj %S/inputs/stats.yaml.wasm -o - | mlir-translate --import-wasm -stats 2>&1 | FileCheck %s
+// Check that we get the correct stats for a module that has a single
+// function, table, memory, and global.
+// REQUIRES: asserts
+
+/* Source code used to create this test:
+(module
+ (type (;0;) (func (param i32) (result i32)))
+ (func (;0;) (type 0) (param i32) (result i32)
+ local.get 0)
+ (table (;0;) 2 funcref)
+ (memory (;0;) 0 65536)
+ (global (;0;) i32 (i32.const 10)))
+*/
+
+// CHECK: 1 wasm-translate - Parsed functions
+// CHECK-NEXT: 0 wasm-translate - Parsed globals
+// CHECK-NEXT: 0 wasm-translate - Parsed memories
+// CHECK-NEXT: 0 wasm-translate - Parsed tables
>From b942297c12f895c2cbfcdd60d18828ddd0438018 Mon Sep 17 00:00:00 2001
From: Luc Forget <dev at alias.lforget.fr>
Date: Mon, 30 Jun 2025 19:11:09 +0200
Subject: [PATCH 02/14] [mlir][wasm] Handling table in Wasm importer
---------
Co-authored-by: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Co-authored-by: Jessica Paquette <jessica.paquette at woven-planet.global>
---
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 18 +++++++++++++++
mlir/test/Target/Wasm/inputs/table.yaml.wasm | 23 ++++++++++++++++++++
mlir/test/Target/Wasm/stats.mlir | 2 +-
3 files changed, 42 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Target/Wasm/inputs/table.yaml.wasm
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 2962cc212f848..23f0cca1a148f 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -701,6 +701,9 @@ class WasmBinaryParser {
if (failed(parsingFunctions))
return;
+ auto parsingTables = parseSection<WasmSectionType::TABLE>();
+ if (failed(parsingTables))
+ return;
// Copy over sizes of containers into statistics.
numFunctionSectionItems = symbols.funcSymbols.size();
@@ -744,6 +747,21 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, size
*import);
}
+/// Parses one TABLE section entry, creates the matching TableOp and records
+/// its symbol in the table symbol list.
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
+                                                           size_t) {
+  auto tableLoc = ph.getLocation();
+  auto parsedType = ph.parseTableType(ctx);
+  if (failed(parsedType))
+    return failure();
+  LLVM_DEBUG(llvm::dbgs() << " Parsed table description: " << *parsedType
+                          << '\n');
+  auto tableName = builder.getStringAttr(symbols.getNewTableSymbolName());
+  auto newTable =
+      builder.create<TableOp>(tableLoc, tableName.strref(), *parsedType);
+  symbols.tableSymbols.push_back({SymbolRefAttr::get(newTable)});
+  return success();
+}
+
template <>
LogicalResult
WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
diff --git a/mlir/test/Target/Wasm/inputs/table.yaml.wasm b/mlir/test/Target/Wasm/inputs/table.yaml.wasm
new file mode 100644
index 0000000000000..387f41820524f
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/table.yaml.wasm
@@ -0,0 +1,23 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TABLE
+ Tables:
+ - Index: 0
+ ElemType: FUNCREF
+ Limits:
+ Minimum: 0x2
+ - Index: 1
+ ElemType: FUNCREF
+ Limits:
+ Flags: [ HAS_MAX ]
+ Minimum: 0x2
+ Maximum: 0x4
+ - Index: 2
+ ElemType: EXTERNREF
+ Limits:
+ Flags: [ HAS_MAX ]
+ Minimum: 0x2
+ Maximum: 0x4
+...
diff --git a/mlir/test/Target/Wasm/stats.mlir b/mlir/test/Target/Wasm/stats.mlir
index e68b85d20f67d..dc30e95343d5a 100644
--- a/mlir/test/Target/Wasm/stats.mlir
+++ b/mlir/test/Target/Wasm/stats.mlir
@@ -16,4 +16,4 @@
// CHECK: 1 wasm-translate - Parsed functions
// CHECK-NEXT: 0 wasm-translate - Parsed globals
// CHECK-NEXT: 0 wasm-translate - Parsed memories
-// CHECK-NEXT: 0 wasm-translate - Parsed tables
+// CHECK-NEXT: 1 wasm-translate - Parsed tables
>From 9af8a9d48496e7d41f10136de16706774392f8b4 Mon Sep 17 00:00:00 2001
From: Luc Forget <dev at alias.lforget.fr>
Date: Mon, 30 Jun 2025 19:27:44 +0200
Subject: [PATCH 03/14] [mlir][wasm] Handle memory in Wasm importer
---------
Co-authored-by: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Co-authored-by: Jessica Paquette <jessica.paquette at woven-planet.global>
---
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 20 +++++++++++++++++++
.../Wasm/inputs/memory_min_eq_max.yaml.wasm | 10 ++++++++++
.../Wasm/inputs/memory_min_max.yaml.wasm | 10 ++++++++++
.../Wasm/inputs/memory_min_no_max.yaml.wasm | 8 ++++++++
mlir/test/Target/Wasm/memory_min_eq_max.mlir | 7 +++++++
mlir/test/Target/Wasm/memory_min_max.mlir | 7 +++++++
mlir/test/Target/Wasm/memory_min_no_max.mlir | 7 +++++++
mlir/test/Target/Wasm/stats.mlir | 2 +-
8 files changed, 70 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm
create mode 100644 mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm
create mode 100644 mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm
create mode 100644 mlir/test/Target/Wasm/memory_min_eq_max.mlir
create mode 100644 mlir/test/Target/Wasm/memory_min_max.mlir
create mode 100644 mlir/test/Target/Wasm/memory_min_no_max.mlir
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 23f0cca1a148f..fe58b43d5d24d 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -705,6 +705,11 @@ class WasmBinaryParser {
if (failed(parsingTables))
return;
+ auto parsingMems = parseSection<WasmSectionType::MEMORY>();
+ if (failed(parsingMems))
+ return;
+
+
// Copy over sizes of containers into statistics.
numFunctionSectionItems = symbols.funcSymbols.size();
numGlobalSectionItems = symbols.globalSymbols.size();
@@ -798,6 +803,21 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
symbols.moduleFuncTypes.push_back(*funcType);
return success();
}
+
+/// Parses one MEMORY section entry (a limit), creates the matching MemOp and
+/// records its symbol in the memory symbol list.
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
+                                                            size_t) {
+  auto memLoc = ph.getLocation();
+  auto parsedLimit = ph.parseLimit(ctx);
+  if (failed(parsedLimit))
+    return failure();
+
+  LLVM_DEBUG(llvm::dbgs() << " Registering memory " << *parsedLimit << '\n');
+  auto memName = symbols.getNewMemorySymbolName();
+  auto newMem = builder.create<MemOp>(memLoc, memName, *parsedLimit);
+  symbols.memSymbols.push_back({SymbolRefAttr::get(newMem)});
+  return success();
+}
} // namespace
namespace mlir {
diff --git a/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm
new file mode 100644
index 0000000000000..f3edf5f2d0cc2
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm
@@ -0,0 +1,10 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: MEMORY
+ Memories:
+ - Flags: [ HAS_MAX ]
+ Minimum: 0x0
+ Maximum: 0x0
+...
diff --git a/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm
new file mode 100644
index 0000000000000..fe70fb686df37
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm
@@ -0,0 +1,10 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: MEMORY
+ Memories:
+ - Flags: [ HAS_MAX ]
+ Minimum: 0x0
+ Maximum: 0x10000
+...
diff --git a/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm
new file mode 100644
index 0000000000000..8508ce38251a3
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm
@@ -0,0 +1,8 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: MEMORY
+ Memories:
+ - Minimum: 0x1
+...
diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
new file mode 100644
index 0000000000000..088e28685d09a
--- /dev/null
+++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
@@ -0,0 +1,7 @@
+// RUN: yaml2obj %S/inputs/memory_min_eq_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module (memory 0 0))
+*/
+
+// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa<limit[0: 0]>, sym_name = "mem_0", sym_visibility = "nested"}> : () -> ()
diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir
new file mode 100644
index 0000000000000..16d3468279d42
--- /dev/null
+++ b/mlir/test/Target/Wasm/memory_min_max.mlir
@@ -0,0 +1,7 @@
+// RUN: yaml2obj %S/inputs/memory_min_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module (memory 0 65536))
+*/
+
+// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa<limit[0: 65536]>, sym_name = "mem_0", sym_visibility = "nested"}> : () -> ()
diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir
new file mode 100644
index 0000000000000..f71cb1098be18
--- /dev/null
+++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir
@@ -0,0 +1,7 @@
+// RUN: yaml2obj %S/inputs/memory_min_no_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module (memory 1))
+*/
+
+// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa<limit[1:]>, sym_name = "mem_0", sym_visibility = "nested"}> : () -> ()
diff --git a/mlir/test/Target/Wasm/stats.mlir b/mlir/test/Target/Wasm/stats.mlir
index dc30e95343d5a..b361de3d99f31 100644
--- a/mlir/test/Target/Wasm/stats.mlir
+++ b/mlir/test/Target/Wasm/stats.mlir
@@ -15,5 +15,5 @@
// CHECK: 1 wasm-translate - Parsed functions
// CHECK-NEXT: 0 wasm-translate - Parsed globals
-// CHECK-NEXT: 0 wasm-translate - Parsed memories
+// CHECK-NEXT: 1 wasm-translate - Parsed memories
// CHECK-NEXT: 1 wasm-translate - Parsed tables
>From 35655b8ea01f068160b206f5a916dbf9406e7b40 Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Tue, 1 Jul 2025 13:42:45 +0900
Subject: [PATCH 04/14] [mlir][wasm] Handling of export at Wasm importer level
--
Co-authored-by: Luc Forget <luc.forget at woven-planet.global>
Co-authored-by: Jessica Paquette <jessica.paquette at woven-planet.global>
---
.../mlir/Target/Wasm/WasmBinaryEncoding.h | 9 +++
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 74 +++++++++++++++++++
.../Wasm/function_export_out_of_scope.yaml | 15 ++++
3 files changed, 98 insertions(+)
create mode 100644 mlir/test/Target/Wasm/function_export_out_of_scope.yaml
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
index e01193e47fdea..f4721d943fe81 100644
--- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -49,6 +49,15 @@ struct WasmBinaryEncoding {
static constexpr std::byte isMutable{0x01};
};
+ /// Byte encodings describing WASM exports.
+ /// Each constant is the kind byte of an export entry and tells which kind of
+ /// module item (function, table, memory or global) the export refers to.
+ struct Export {
+ static constexpr std::byte function{0x00};
+ static constexpr std::byte table{0x01};
+ static constexpr std::byte memory{0x02};
+ static constexpr std::byte global{0x03};
+ };
+
+ static constexpr std::byte endByte{0x0B};
};
} // namespace mlir
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index fe58b43d5d24d..1b0235e7d6f90 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/Wasm/WasmImporter.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LEB128.h"
#include <variant>
@@ -709,6 +710,9 @@ class WasmBinaryParser {
if (failed(parsingMems))
return;
+ auto parsingExports = parseSection<WasmSectionType::EXPORT>();
+ if (failed(parsingExports))
+ return;
// Copy over sizes of containers into statistics.
numFunctionSectionItems = symbols.funcSymbols.size();
@@ -752,6 +756,76 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, size
*import);
}
+/// Parses one EXPORT section entry: a name, a kind byte and an index into the
+/// symbol list of that kind. The designated symbol is made public and renamed
+/// to the export name.
+///
+/// Fails when the kind byte is unknown, when the index does not designate an
+/// existing item, or when the designated symbol cannot be found in the module.
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
+                                                            size_t) {
+  auto exportLoc = ph.getLocation();
+
+  auto exportName = ph.parseName();
+  if (failed(exportName))
+    return failure();
+
+  auto opcode = ph.consumeByte();
+  if (failed(opcode))
+    return failure();
+
+  auto idx = ph.parseLiteral<uint32_t>();
+  if (failed(idx))
+    return failure();
+
+  using SymbolRefDesc =
+      std::variant<llvm::SmallVector<SymbolRefContainer>,
+                   llvm::SmallVector<GlobalSymbolRefContainer>,
+                   llvm::SmallVector<FunctionSymbolRefContainer>>;
+
+  // Select the symbol list matching the export kind; `symbolType` is only used
+  // to build diagnostics.
+  SymbolRefDesc currentSymbolList;
+  std::string symbolType = "";
+  switch (*opcode) {
+  case WasmBinaryEncoding::Export::function:
+    symbolType = "function";
+    currentSymbolList = symbols.funcSymbols;
+    break;
+  case WasmBinaryEncoding::Export::table:
+    symbolType = "table";
+    currentSymbolList = symbols.tableSymbols;
+    break;
+  case WasmBinaryEncoding::Export::memory:
+    symbolType = "memory";
+    currentSymbolList = symbols.memSymbols;
+    break;
+  case WasmBinaryEncoding::Export::global:
+    symbolType = "global";
+    currentSymbolList = symbols.globalSymbols;
+    break;
+  default:
+    return emitError(exportLoc, "Invalid value for export type: ")
+           << std::to_integer<unsigned>(*opcode);
+  }
+
+  auto currentSymbol = std::visit(
+      [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> {
+        // Valid indices are [0, size): an index equal to the size is already
+        // out of range and must not be used to subscript the list.
+        if (*idx >= list.size()) {
+          emitError(
+              exportLoc,
+              llvm::formatv(
+                  "Trying to export {0} {1} which is undefined in this scope",
+                  symbolType, *idx));
+          return failure();
+        }
+        return list[*idx].symbol;
+      },
+      currentSymbolList);
+
+  if (failed(currentSymbol))
+    return failure();
+
+  Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
+  // The lookup returns null when the recorded reference no longer names an
+  // operation of the module (e.g. the symbol was renamed by an earlier
+  // export); report it rather than dereferencing a null operation.
+  if (!op)
+    return emitError(
+        exportLoc,
+        llvm::formatv("Cannot find the symbol of {0} {1} in the module",
+                      symbolType, *idx));
+  SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
+  auto symName = SymbolTable::getSymbolName(op);
+  return SymbolTable{mOp}.rename(symName, *exportName);
+}
+
template <>
LogicalResult
WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, size_t) {
diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
new file mode 100644
index 0000000000000..ffb26f563141a
--- /dev/null
+++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
@@ -0,0 +1,15 @@
+# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
+
+# FIXME: The error code here should be nonzero.
+
+# CHECK: Trying to export function 42 which is undefined in this scope
+
+--- !WASM
+FileHeader:
+ Version: 0x00000001
+Sections:
+ - Type: EXPORT
+ Exports:
+ - Name: function_export
+ Kind: FUNCTION
+ Index: 42
>From 1ededd4183e903675aebf09fca3405f7c216e213 Mon Sep 17 00:00:00 2001
From: Luc Forget <dev at alias.lforget.fr>
Date: Thu, 3 Jul 2025 09:49:59 +0900
Subject: [PATCH 05/14] [mlir][wasm] Expression parsing mechanism for Wasm
importer
---------
Co-authored-by: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Co-authored-by: Jessica Paquette <jessica.paquette at woven-planet.global>
---
.../mlir/Target/Wasm/WasmBinaryEncoding.h | 9 +
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 333 ++++++++++++++++++
2 files changed, 342 insertions(+)
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
index f4721d943fe81..a5b124eecbe67 100644
--- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -16,6 +16,15 @@
#include <cstddef>
namespace mlir {
struct WasmBinaryEncoding {
+ /// Byte encodings for WASM instructions.
+ struct OpCode {
+ // Locals, globals, constants.
+ // Only the four constant-producing instructions are listed so far: one per
+ // numeric value type (32/64-bit integer, 32/64-bit floating point).
+ static constexpr std::byte constI32{0x41};
+ static constexpr std::byte constI64{0x42};
+ static constexpr std::byte constFP32{0x43};
+ static constexpr std::byte constFP64{0x44};
+ };
+
/// Byte encodings of types in WASM binaries
struct Type {
static constexpr std::byte emptyBlockType{0x40};
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 1b0235e7d6f90..753513eb9d887 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -142,6 +142,8 @@ struct FunctionSymbolRefContainer : SymbolRefContainer {
using ImportDesc = std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+using parsed_inst_t = llvm::FailureOr<llvm::SmallVector<Value>>;
+
struct WasmModuleSymbolTables {
llvm::SmallVector<FunctionSymbolRefContainer> funcSymbols;
llvm::SmallVector<GlobalSymbolRefContainer> globalSymbols;
@@ -173,6 +175,134 @@ struct WasmModuleSymbolTables {
return getNewSymbolName("table_", id);
}
};
+
+class ParserHead;
+
+/// Wrapper around SmallVector to only allow access as push and pop on the
+/// stack. Makes sure that there are no "free accesses" on the stack to preserve
+/// its state.
+class ValueStack {
+private:
+ // NOTE(review): not referenced by any member visible in this class; it looks
+ // like groundwork for label/block handling - confirm before relying on it.
+ struct LabelLevel {
+ size_t stackIdx;
+ LabelLevelOpInterface levelOp;
+ };
+public:
+ /// Returns true when no value is currently on the stack.
+ bool empty() const { return values.empty(); }
+
+ /// Returns the number of values currently on the stack.
+ size_t size() const { return values.size(); }
+
+ /// Pops values from the stack because they are being used in an operation.
+ /// @param operandTypes The list of expected types of the operation, used
+ /// to know how many values to pop and check if the types match the
+ /// expectation.
+ /// @param opLoc Location of the caller, used to accurately report the
+ /// location if an error occurs.
+ /// @return Failure or the vector of popped values.
+ llvm::FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes,
+ Location *opLoc);
+
+ /// Push the results of an operation to the stack so they can be used in a
+ /// following operation.
+ /// @param results The list of results of the operation
+ /// @param opLoc Location of the caller, used to accurately report the
+ /// location if an error occurs.
+ /// @return Failure if one of the results cannot be pushed, success otherwise.
+ LogicalResult pushResults(ValueRange results, Location *opLoc);
+
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// A simple dump function for debugging.
+ /// Writes output to llvm::dbgs().
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+
+private:
+ /// Backing storage; the top of the stack is the last element.
+ llvm::SmallVector<Value> values;
+};
+
+using local_val_t = TypedValue<wasmssa::LocalRefType>;
+
+/// Parses a sequence of Wasm instructions read through a ParserHead, keeping
+/// the values produced by the parsed instructions on an internal ValueStack.
+class ExpressionParser {
+public:
+  using locals_t = llvm::SmallVector<local_val_t>;
+  ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
+                   llvm::ArrayRef<local_val_t> initLocal)
+      : parser{parser}, symbols{symbols}, locals{initLocal} {}
+
+private:
+  /// Parser for the instruction whose encoding is `opCode`. Supported
+  /// instructions are registered by specialising this template.
+  template <std::byte opCode>
+  inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
+
+  /// Parses a literal of type `valueT` and builds the constant operation
+  /// holding it.
+  template <typename valueT>
+  parsed_inst_t
+  parseConstInst(OpBuilder &builder,
+                 std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
+
+  /// This function generates a dispatch tree to associate an opcode with a
+  /// parser. Parsers are registered by specialising the
+  /// `parseSpecificInstruction` function for the op code to handle.
+  ///
+  /// The dispatcher is generated by recursively creating all possible patterns
+  /// for an opcode and calling the relevant parser on the leaf.
+  ///
+  /// @tparam patternBitSize is the number of most significant opcode bits
+  /// already fixed by this instance; the next bit examined is bit
+  /// `7 - patternBitSize`.
+  ///
+  /// @tparam highBitPattern is the value of those `patternBitSize` most
+  /// significant bits. Once all 8 bits are fixed it is the opcode itself.
+  template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
+  inline parsed_inst_t dispatchToInstParser(std::byte opCode,
+                                            OpBuilder &builder) {
+    static_assert(patternBitSize <= 8,
+                  "PatternBitSize is outside of range of opcode space! "
+                  "(expected at most 8 bits)");
+    if constexpr (patternBitSize < 8) {
+      constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
+      constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
+      constexpr size_t nextPatternBitSize = patternBitSize + 1;
+      if ((opCode & bitSelect) != std::byte{0})
+        return dispatchToInstParser<nextPatternBitSize,
+                                    nextHighBitPatternStem | std::byte{1}>(
+            opCode, builder);
+      return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
+          opCode, builder);
+    } else {
+      return parseSpecificInstruction<highBitPattern>(builder);
+    }
+  }
+
+  /// Results of the last parsed instruction together with the byte that ended
+  /// the parse.
+  struct ParseResultWithInfo {
+    llvm::SmallVector<Value> opResults;
+    std::byte endingByte;
+  };
+
+public:
+  /// Parses instructions until the byte `ParseEndByte` is read.
+  template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
+  parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
+
+  /// Parses instructions until one of the `ExpressionParseEnd` bytes is read,
+  /// and reports which one ended the parse.
+  template <std::byte... ExpressionParseEnd>
+  llvm::FailureOr<ParseResultWithInfo>
+  parse(OpBuilder &builder,
+        ByteSequence<ExpressionParseEnd...> parsingEndFilters);
+
+  /// Pops operands of the given types from the value stack; errors are
+  /// reported at the location of the instruction currently being parsed.
+  llvm::FailureOr<llvm::SmallVector<Value>>
+  popOperands(TypeRange operandTypes) {
+    return valueStack.popOperands(operandTypes, &currentOpLoc.value());
+  }
+
+  /// Pushes `results` on the value stack; errors are reported at the location
+  /// of the instruction currently being parsed.
+  LogicalResult pushResults(ValueRange results) {
+    return valueStack.pushResults(results, &currentOpLoc.value());
+  }
+
+private:
+  /// Location of the instruction being parsed; set by `parse` before each
+  /// opcode is consumed.
+  std::optional<Location> currentOpLoc;
+  ParserHead &parser;
+  WasmModuleSymbolTables const &symbols;
+  locals_t locals;
+  ValueStack valueStack;
+};
+
class ParserHead {
public:
ParserHead(llvm::StringRef src, StringAttr name) : head{src}, locName{name} {}
@@ -382,6 +512,14 @@ class ParserHead {
<< static_cast<int>(*importType);
}
}
+
+  /// Parses an expression from the current position of this head using a
+  /// freshly constructed ExpressionParser initialised with `locals`.
+  parsed_inst_t parseExpression(OpBuilder &builder,
+                                WasmModuleSymbolTables const &symbols,
+                                llvm::ArrayRef<local_val_t> locals = {}) {
+    ExpressionParser exprParser{*this, symbols, locals};
+    return exprParser.parse(builder);
+  }
+
bool end() const { return curHead().empty(); }
ParserHead copy() const {
@@ -491,6 +629,201 @@ inline llvm::FailureOr<int64_t> ParserHead::parseI64() {
return parseLiteral<int64_t>();
}
+/// Fallback used for every opcode without a dedicated specialisation: reports
+/// the opcode as unknown at the location of the current instruction.
+template <std::byte opCode>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
+  constexpr int opCodeValue = std::to_integer<int>(opCode);
+  return emitError(*currentOpLoc, "Unknown instruction opcode: ")
+         << opCodeValue;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+/// Prints the content of the stack to llvm::dbgs(), top of the stack first.
+void ValueStack::dump() const {
+  llvm::dbgs() << "================= Wasm ValueStack =======================\n";
+  llvm::dbgs() << "size: " << size() << "\n";
+  llvm::dbgs() << "<Top>"
+               << "\n";
+  // Stack is pushed to via push_back. Therefore the top of the stack is the
+  // end of the vector. Iterate in reverse so that the first thing we print
+  // is the top of the stack.
+  size_t stackSize = size();
+  // The index must advance on every iteration, otherwise the loop never
+  // terminates on a non-empty stack.
+  for (size_t idx = 0; idx < stackSize; ++idx) {
+    size_t actualIdx = stackSize - 1 - idx;
+    llvm::dbgs() << " ";
+    values[actualIdx].dump();
+  }
+  llvm::dbgs() << "<Bottom>"
+               << "\n";
+  llvm::dbgs() << "=========================================================\n";
+}
+#endif
+
+/// Removes the `operandTypes.size()` topmost values from the stack and returns
+/// them ordered from the deepest to the topmost. Fails, leaving the stack
+/// untouched, when the stack is too small or when a value's type differs from
+/// the expected one.
+parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
+  const size_t requested = operandTypes.size();
+  LLVM_DEBUG(llvm::dbgs() << "Popping from ValueStack\n");
+  LLVM_DEBUG(llvm::dbgs() << " Elements(s) to pop: " << requested << "\n");
+  LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
+  if (requested > values.size())
+    return emitError(*opLoc,
+                     "Stack doesn't contain enough values. Trying to get ")
+           << requested << " operands on a stack containing only "
+           << values.size() << " values.";
+  // Index of the deepest value that gets popped.
+  const size_t firstOperandIdx = values.size() - requested;
+  llvm::SmallVector<Value> popped{};
+  popped.reserve(requested);
+  for (size_t i = 0; i != requested; ++i) {
+    Value candidate = values[firstOperandIdx + i];
+    Type actualType = candidate.getType();
+    if (actualType != operandTypes[i])
+      return emitError(*opLoc, "Invalid operand type on stack. Expecting ")
+             << operandTypes[i] << ", value on stack is of type " << actualType
+             << ".";
+    LLVM_DEBUG(llvm::dbgs() << " POP: " << candidate << "\n");
+    popped.push_back(candidate);
+  }
+  values.resize(firstOperandIdx);
+  LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n");
+  return popped;
+}
+
+/// Appends `results` to the stack in order. Stops and fails at the first value
+/// whose type is not a Wasm value type; values pushed before it stay on the
+/// stack.
+LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
+  LLVM_DEBUG(llvm::dbgs() << "Pushing to ValueStack\n");
+  LLVM_DEBUG(llvm::dbgs() << " Elements(s) to push: " << results.size()
+                          << "\n");
+  LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
+  for (Value produced : results) {
+    Type producedType = produced.getType();
+    if (!isWasmValueType(producedType))
+      return emitError(*opLoc, "Invalid value type on stack: ")
+             << produced.getType();
+    LLVM_DEBUG(llvm::dbgs() << " PUSH: " << produced << "\n");
+    values.push_back(produced);
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n");
+  return success();
+}
+
+/// Single-end-byte convenience overload: parses until `EndParseByte` is read
+/// and returns only the instruction results, dropping the ending byte.
+template <std::byte EndParseByte>
+parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
+                                      UniqueByte<EndParseByte>) {
+  auto parsed = parse(builder, ByteSequence<EndParseByte>{});
+  if (succeeded(parsed))
+    return parsed->opResults;
+  return failure();
+}
+
+/// Parses instructions one opcode at a time until one of the bytes of
+/// `parsingEndFilters` is read. Returns the results of the last parsed
+/// instruction (empty if there was none) along with the byte that ended the
+/// parse.
+template <std::byte... ExpressionParseEnd>
+llvm::FailureOr<ExpressionParser::ParseResultWithInfo>
+ExpressionParser::parse(OpBuilder &builder,
+                        ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
+  llvm::SmallVector<Value> lastResults;
+  while (true) {
+    // Record where the instruction starts so diagnostics point at it.
+    currentOpLoc = parser.getLocation();
+    auto nextByte = parser.consumeByte();
+    if (failed(nextByte))
+      return failure();
+    if (isValueOneOf(*nextByte, parsingEndFilters))
+      return {{lastResults, *nextByte}};
+    parsed_inst_t instResults = dispatchToInstParser(*nextByte, builder);
+    if (failed(instResults))
+      return failure();
+    std::swap(lastResults, *instResults);
+    if (failed(pushResults(lastResults)))
+      return failure();
+  }
+}
+
+
+/// Maps a C++ arithmetic type to the builtin MLIR type used for literals of
+/// that type. Only declared here; each supported type has an explicit
+/// specialisation below.
+template <typename T>
+inline Type buildLiteralType(OpBuilder &);
+
+template <>
+inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+// Unsigned integers map to the same types as their signed counterparts.
+template <>
+inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+template <>
+inline Type buildLiteralType<float>(OpBuilder &builder) {
+ return builder.getF32Type();
+}
+
+template <>
+inline Type buildLiteralType<double>(OpBuilder &builder) {
+ return builder.getF64Type();
+}
+
+/// Trait selecting the attribute class able to hold a literal of type `ValT`:
+/// IntegerAttr for integral types, FloatAttr for floating point types.
+template<typename ValT, typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+struct AttrHolder;
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
+ using type = IntegerAttr;
+};
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
+ using type = FloatAttr;
+};
+
+/// Shorthand for the attribute class selected by AttrHolder.
+template<typename ValT>
+using attr_holder_t = typename AttrHolder<ValT>::type;
+
+/// Builds the attribute holding `val`, typed with the MLIR type associated to
+/// `ValT` by `buildLiteralType`.
+template <typename ValT,
+          typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
+  using AttrT = attr_holder_t<ValT>;
+  Type literalType = buildLiteralType<ValT>(builder);
+  return AttrT::get(literalType, val);
+}
+
+/// Reads a literal of type `valueT` from the input and creates a ConstOp
+/// holding it at the location of the current instruction. Returns the single
+/// result of that operation.
+template <typename valueT>
+parsed_inst_t ExpressionParser::parseConstInst(
+    OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
+  auto literal = parser.parseLiteral<valueT>();
+  if (failed(literal))
+    return failure();
+  auto literalAttr = buildLiteralAttr<valueT>(builder, *literal);
+  auto constOp = builder.create<ConstOp>(*currentOpLoc, literalAttr);
+  return {{constOp.getResult()}};
+}
+
+// Parsers for the four constant instructions. Each one forwards to
+// `parseConstInst` with the C++ type of the literal that follows the opcode.
+
+/// Handles the `constI32` opcode: 32-bit integer literal.
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
+ return parseConstInst<int32_t>(builder);
+}
+
+/// Handles the `constI64` opcode: 64-bit integer literal.
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
+ return parseConstInst<int64_t>(builder);
+}
+
+/// Handles the `constFP32` opcode: `float` literal.
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
+ return parseConstInst<float>(builder);
+}
+
+/// Handles the `constFP64` opcode: `double` literal.
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
+ return parseConstInst<double>(builder);
+}
+
+
class WasmBinaryParser {
private:
struct SectionRegistry {
>From 2cb8732c4a731d3ddee528bf15e383c810f982ed Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Thu, 31 Jul 2025 13:45:21 +0900
Subject: [PATCH 06/14] Formatting and other comments from the previous PR
---
mlir/include/mlir/InitAllTranslations.h | 1 +
mlir/include/mlir/Target/Wasm/WasmImporter.h | 9 +-
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 428 +++++++++---------
.../lib/Target/Wasm/TranslateRegistration.cpp | 16 +-
mlir/test/Target/Wasm/bad_wasm_version.yaml | 2 +-
.../Wasm/function_export_out_of_scope.yaml | 2 +-
.../Wasm/invalid_function_type_index.yaml | 2 +-
mlir/test/Target/Wasm/missing_header.yaml | 2 +-
8 files changed, 232 insertions(+), 230 deletions(-)
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index cf8f108b88159..622024db5a8a2 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -17,6 +17,7 @@
#include "mlir/Target/IRDLToCpp/TranslationRegistration.h"
namespace mlir {
+
void registerFromLLVMIRTranslation();
void registerFromSPIRVTranslation();
void registerFromWasmTranslation();
diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h
index fc7d275353964..5cc42a1f32fa4 100644
--- a/mlir/include/mlir/Target/Wasm/WasmImporter.h
+++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h
@@ -19,8 +19,7 @@
#include "mlir/IR/OwningOpRef.h"
#include "llvm/Support/SourceMgr.h"
-namespace mlir {
-namespace wasm {
+namespace mlir::wasm {
/// Translates the given operation to C++ code. The operation or operations in
/// the region of 'op' need almost all be in EmitC dialect. The parameter
@@ -28,8 +27,8 @@ namespace wasm {
/// arguments are declared at the beginning of the function.
/// If parameter 'fileId' is non-empty, then body of `emitc.file` ops
/// with matching id are emitted.
-OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext* context);
-} // namespace wasm
-} // namespace mlir
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+ MLIRContext *context);
+} // namespace mlir::wasm
#endif // MLIR_TARGET_WASM_WASMIMPORTER_H
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 753513eb9d887..dd8b86670c31a 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -9,6 +9,8 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
#include "mlir/Target/Wasm/WasmBinaryEncoding.h"
#include "mlir/Target/Wasm/WasmImporter.h"
#include "llvm/ADT/Statistic.h"
@@ -16,6 +18,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LEB128.h"
+#include <cstdint>
#include <variant>
#define DEBUG_TYPE "wasm-translate"
@@ -26,7 +29,8 @@ STATISTIC(numGlobalSectionItems, "Parsed globals");
STATISTIC(numMemorySectionItems, "Parsed memories");
STATISTIC(numTableSectionItems, "Parsed tables");
-static_assert(CHAR_BIT == 8, "This code expects std::byte to be exactly 8 bits");
+static_assert(CHAR_BIT == 8,
+ "This code expects std::byte to be exactly 8 bits");
using namespace mlir;
using namespace mlir::wasm;
@@ -51,7 +55,7 @@ enum struct WasmSectionType : section_id_t {
};
constexpr section_id_t highestWasmSectionID{
- static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
+ static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
#define APPLY_WASM_SEC_TRANSFORM \
WASM_SEC_TRANSFORM(CUSTOM) \
@@ -82,7 +86,7 @@ constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
}
template <std::byte... Bytes>
-struct ByteSequence{};
+struct ByteSequence {};
template <std::byte... Bytes1, std::byte... Bytes2>
constexpr ByteSequence<Bytes1..., Bytes2...>
@@ -91,7 +95,7 @@ operator+(ByteSequence<Bytes1...>, ByteSequence<Bytes2...>) {
}
/// Template class for representing a byte sequence of only one byte
-template<std::byte Byte>
+template <std::byte Byte>
struct UniqueByte : ByteSequence<Byte> {};
template <typename T, T... Values>
@@ -109,12 +113,13 @@ constexpr ByteSequence<
WasmBinaryEncoding::Type::v128>
valueTypesEncodings{};
-template<std::byte... allowedFlags>
-constexpr bool isValueOneOf(std::byte value, ByteSequence<allowedFlags...> = {}) {
- return ((value == allowedFlags) | ... | false);
+template <std::byte... allowedFlags>
+constexpr bool isValueOneOf(std::byte value,
+ ByteSequence<allowedFlags...> = {}) {
+ return ((value == allowedFlags) | ... | false);
}
-template<std::byte... flags>
+template <std::byte... flags>
constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
return !isValueOneOf<flags...>(value);
}
@@ -140,19 +145,20 @@ struct FunctionSymbolRefContainer : SymbolRefContainer {
FunctionType functionType;
};
-using ImportDesc = std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+using ImportDesc =
+ std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
-using parsed_inst_t = llvm::FailureOr<llvm::SmallVector<Value>>;
+using parsed_inst_t = FailureOr<SmallVector<Value>>;
struct WasmModuleSymbolTables {
- llvm::SmallVector<FunctionSymbolRefContainer> funcSymbols;
- llvm::SmallVector<GlobalSymbolRefContainer> globalSymbols;
- llvm::SmallVector<SymbolRefContainer> memSymbols;
- llvm::SmallVector<SymbolRefContainer> tableSymbols;
- llvm::SmallVector<FunctionType> moduleFuncTypes;
-
- std::string getNewSymbolName(llvm::StringRef prefix, size_t id) const {
- return (prefix + llvm::Twine{id}).str();
+ SmallVector<FunctionSymbolRefContainer> funcSymbols;
+ SmallVector<GlobalSymbolRefContainer> globalSymbols;
+ SmallVector<SymbolRefContainer> memSymbols;
+ SmallVector<SymbolRefContainer> tableSymbols;
+ SmallVector<FunctionType> moduleFuncTypes;
+
+ std::string getNewSymbolName(StringRef prefix, size_t id) const {
+ return (prefix + Twine{id}).str();
}
std::string getNewFuncSymbolName() const {
@@ -187,6 +193,7 @@ class ValueStack {
size_t stackIdx;
LabelLevelOpInterface levelOp;
};
+
public:
bool empty() const { return values.empty(); }
@@ -200,8 +207,8 @@ class ValueStack {
/// location
/// if an error occurs.
/// @return Failure or the vector of popped values.
- llvm::FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes,
- Location *opLoc);
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
+ Location *opLoc);
/// Push the results of an operation to the stack so they can be used in a
/// following operation.
@@ -211,7 +218,6 @@ class ValueStack {
/// if an error occurs.
LogicalResult pushResults(ValueRange results, Location *opLoc);
-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// A simple dump function for debugging.
/// Writes output to llvm::dbgs().
@@ -219,16 +225,16 @@ class ValueStack {
#endif
private:
- llvm::SmallVector<Value> values;
+ SmallVector<Value> values;
};
using local_val_t = TypedValue<wasmssa::LocalRefType>;
class ExpressionParser {
public:
- using locals_t = llvm::SmallVector<local_val_t>;
+ using locals_t = SmallVector<local_val_t>;
ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
- llvm::ArrayRef<local_val_t> initLocal)
+ ArrayRef<local_val_t> initLocal)
: parser{parser}, symbols{symbols}, locals{initLocal} {}
private:
@@ -240,7 +246,6 @@ class ExpressionParser {
parseConstInst(OpBuilder &builder,
std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
-
/// This function generates a dispatch tree to associate an opcode with a
/// parser. Parsers are registered by specialising the
/// `parseSpecificInstruction` function for the op code to handle.
@@ -263,8 +268,9 @@ class ExpressionParser {
constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
constexpr size_t nextPatternBitSize = patternBitSize + 1;
if ((opCode & bitSelect) != std::byte{0})
- return dispatchToInstParser < nextPatternBitSize,
- nextHighBitPatternStem | std::byte{1} > (opCode, builder);
+ return dispatchToInstParser<nextPatternBitSize,
+ nextHighBitPatternStem | std::byte{1}>(
+ opCode, builder);
return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
opCode, builder);
} else {
@@ -273,40 +279,40 @@ class ExpressionParser {
}
struct ParseResultWithInfo {
- llvm::SmallVector<Value> opResults;
+ SmallVector<Value> opResults;
std::byte endingByte;
};
public:
- template<std::byte ParseEndByte = WasmBinaryEncoding::endByte>
- parsed_inst_t parse(OpBuilder &builder,
- UniqueByte<ParseEndByte> = {});
+ template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
+ parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
template <std::byte... ExpressionParseEnd>
- llvm::FailureOr<ParseResultWithInfo>
+ FailureOr<ParseResultWithInfo>
parse(OpBuilder &builder,
ByteSequence<ExpressionParseEnd...> parsingEndFilters);
- llvm::FailureOr<llvm::SmallVector<Value>>
- popOperands(TypeRange operandTypes) {
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
return valueStack.popOperands(operandTypes, ¤tOpLoc.value());
}
LogicalResult pushResults(ValueRange results) {
return valueStack.pushResults(results, ¤tOpLoc.value());
}
+
private:
std::optional<Location> currentOpLoc;
ParserHead &parser;
WasmModuleSymbolTables const &symbols;
locals_t locals;
ValueStack valueStack;
- };
+};
class ParserHead {
public:
- ParserHead(llvm::StringRef src, StringAttr name) : head{src}, locName{name} {}
+ ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
ParserHead(ParserHead &&) = default;
+
private:
ParserHead(ParserHead const &other) = default;
@@ -315,7 +321,7 @@ class ParserHead {
return FileLineColLoc::get(locName, 0, anchorOffset + offset);
}
- llvm::FailureOr<llvm::StringRef> consumeNBytes(size_t nBytes) {
+ FailureOr<StringRef> consumeNBytes(size_t nBytes) {
LLVM_DEBUG(llvm::dbgs() << "Consume " << nBytes << " bytes\n");
LLVM_DEBUG(llvm::dbgs() << " Bytes remaining: " << size() << "\n");
LLVM_DEBUG(llvm::dbgs() << " Current offset: " << offset << "\n");
@@ -323,14 +329,14 @@ class ParserHead {
return emitError(getLocation(), "trying to extract ")
<< nBytes << "bytes when only " << size() << "are avilables";
- auto res = head.slice(offset, offset + nBytes);
+ StringRef res = head.slice(offset, offset + nBytes);
offset += nBytes;
LLVM_DEBUG(llvm::dbgs()
<< " Updated offset (+" << nBytes << "): " << offset << "\n");
return res;
}
- llvm::FailureOr<std::byte> consumeByte() {
+ FailureOr<std::byte> consumeByte() {
auto res = consumeNBytes(1);
if (failed(res))
return failure();
@@ -338,52 +344,52 @@ class ParserHead {
}
template <typename T>
- llvm::FailureOr<T> parseLiteral();
+ FailureOr<T> parseLiteral();
- llvm::FailureOr<uint32_t> parseVectorSize();
+ FailureOr<uint32_t> parseVectorSize();
private:
// TODO: This is equivalent to parseLiteral<uint32_t> and could be removed
// if parseLiteral specialization were moved here, but default GCC on Ubuntu
// 22.04 has bug with template specialization in class declaration
- inline llvm::FailureOr<uint32_t> parseUI32();
- inline llvm::FailureOr<int64_t> parseI64();
+ inline FailureOr<uint32_t> parseUI32();
+ inline FailureOr<int64_t> parseI64();
public:
- llvm::FailureOr<llvm::StringRef> parseName() {
- auto size = parseVectorSize();
+ FailureOr<StringRef> parseName() {
+ FailureOr<uint32_t> size = parseVectorSize();
if (failed(size))
return failure();
return consumeNBytes(*size);
}
- llvm::FailureOr<WasmSectionType> parseWasmSectionType() {
- auto id = consumeByte();
+ FailureOr<WasmSectionType> parseWasmSectionType() {
+ FailureOr<std::byte> id = consumeByte();
if (failed(id))
return failure();
if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
- return emitError(getLocation(), "Invalid section ID: ")
+ return emitError(getLocation(), "invalid section ID: ")
<< static_cast<int>(*id);
return static_cast<WasmSectionType>(*id);
}
- llvm::FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
+ FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
using WasmLimits = WasmBinaryEncoding::LimitHeader;
- auto limitLocation = getLocation();
- auto limitHeader = consumeByte();
+ FileLineColLoc limitLocation = getLocation();
+ FailureOr<std::byte> limitHeader = consumeByte();
if (failed(limitHeader))
return failure();
if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
- return emitError(limitLocation, "Invalid limit header: ")
+ return emitError(limitLocation, "invalid limit header: ")
<< static_cast<int>(*limitHeader);
- auto minParse = parseUI32();
+ FailureOr<uint32_t> minParse = parseUI32();
if (failed(minParse))
return failure();
std::optional<uint32_t> max{std::nullopt};
if (*limitHeader == WasmLimits::bothLimits) {
- auto maxParse = parseUI32();
+ FailureOr<uint32_t> maxParse = parseUI32();
if (failed(maxParse))
return failure();
max = *maxParse;
@@ -391,9 +397,9 @@ class ParserHead {
return LimitType::get(ctx, *minParse, max);
}
- llvm::FailureOr<Type> parseValueType(MLIRContext *ctx) {
- auto typeLoc = getLocation();
- auto typeEncoding = consumeByte();
+ FailureOr<Type> parseValueType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> typeEncoding = consumeByte();
if (failed(typeEncoding))
return failure();
switch (*typeEncoding) {
@@ -412,35 +418,35 @@ class ParserHead {
case WasmBinaryEncoding::Type::externRef:
return wasmssa::ExternRefType::get(ctx);
default:
- return emitError(typeLoc, "Invalid value type encoding: ")
+ return emitError(typeLoc, "invalid value type encoding: ")
<< static_cast<int>(*typeEncoding);
}
}
- llvm::FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
+ FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
- auto typeParsed = parseValueType(ctx);
+ FailureOr<Type> typeParsed = parseValueType(ctx);
if (failed(typeParsed))
return failure();
- auto mutLoc = getLocation();
- auto mutSpec = consumeByte();
+ FileLineColLoc mutLoc = getLocation();
+ FailureOr<std::byte> mutSpec = consumeByte();
if (failed(mutSpec))
return failure();
if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
- return emitError(mutLoc, "Invalid global mutability specifier: ")
+ return emitError(mutLoc, "invalid global mutability specifier: ")
<< static_cast<int>(*mutSpec);
return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
}
- llvm::FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
- auto nParamsParsed = parseVectorSize();
+ FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
+ FailureOr<uint32_t> nParamsParsed = parseVectorSize();
if (failed(nParamsParsed))
return failure();
- auto nParams = *nParamsParsed;
- llvm::SmallVector<Type> res{};
+ uint32_t nParams = *nParamsParsed;
+ SmallVector<Type> res{};
res.reserve(nParams);
for (size_t i = 0; i < nParams; ++i) {
- auto parsedType = parseValueType(ctx);
+ FailureOr<Type> parsedType = parseValueType(ctx);
if (failed(parsedType))
return failure();
res.push_back(*parsedType);
@@ -448,50 +454,49 @@ class ParserHead {
return TupleType::get(ctx, res);
}
- llvm::FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
- auto typeLoc = getLocation();
- auto funcTypeHeader = consumeByte();
+ FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> funcTypeHeader = consumeByte();
if (failed(funcTypeHeader))
return failure();
if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
- return emitError(typeLoc, "Invalid function type header byte. Expecting ")
- << std::to_integer<unsigned>(
- WasmBinaryEncoding::Type::funcType)
+ return emitError(typeLoc, "invalid function type header byte. Expecting ")
+ << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
<< " got " << std::to_integer<unsigned>(*funcTypeHeader);
- auto inputTypes = parseResultType(ctx);
+ FailureOr<TupleType> inputTypes = parseResultType(ctx);
if (failed(inputTypes))
return failure();
- auto resTypes = parseResultType(ctx);
+ FailureOr<TupleType> resTypes = parseResultType(ctx);
if (failed(resTypes))
return failure();
return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
}
- llvm::FailureOr<TypeIdxRecord> parseTypeIndex() {
- auto res = parseUI32();
+ FailureOr<TypeIdxRecord> parseTypeIndex() {
+ FailureOr<uint32_t> res = parseUI32();
if (failed(res))
return failure();
return TypeIdxRecord{*res};
}
- llvm::FailureOr<TableType> parseTableType(MLIRContext *ctx) {
- auto elmTypeParse = parseValueType(ctx);
+ FailureOr<TableType> parseTableType(MLIRContext *ctx) {
+ FailureOr<Type> elmTypeParse = parseValueType(ctx);
if (failed(elmTypeParse))
return failure();
if (!isWasmRefType(*elmTypeParse))
- return emitError(getLocation(), "Invalid element type for table");
- auto limitParse = parseLimit(ctx);
+ return emitError(getLocation(), "invalid element type for table");
+ FailureOr<LimitType> limitParse = parseLimit(ctx);
if (failed(limitParse))
return failure();
return TableType::get(ctx, *elmTypeParse, *limitParse);
}
- llvm::FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
- auto importLoc = getLocation();
- auto importType = consumeByte();
- auto packager = [](auto parseResult) -> llvm::FailureOr<ImportDesc> {
+ FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
+ FileLineColLoc importLoc = getLocation();
+ FailureOr<std::byte> importType = consumeByte();
+ auto packager = [](auto parseResult) -> FailureOr<ImportDesc> {
if (llvm::failed(parseResult))
return failure();
return {*parseResult};
@@ -508,28 +513,26 @@ class ParserHead {
case WasmBinaryEncoding::Import::globalType:
return packager(parseGlobalType(ctx));
default:
- return emitError(importLoc, "Invalid import type descriptor: ")
+ return emitError(importLoc, "invalid import type descriptor: ")
<< static_cast<int>(*importType);
}
}
parsed_inst_t parseExpression(OpBuilder &builder,
WasmModuleSymbolTables const &symbols,
- llvm::ArrayRef<local_val_t> locals = {}) {
+ ArrayRef<local_val_t> locals = {}) {
auto eParser = ExpressionParser{*this, symbols, locals};
return eParser.parse(builder);
}
bool end() const { return curHead().empty(); }
- ParserHead copy() const {
- return *this;
- }
+ ParserHead copy() const { return *this; }
private:
- llvm::StringRef curHead() const { return head.drop_front(offset); }
+ StringRef curHead() const { return head.drop_front(offset); }
- llvm::FailureOr<std::byte> peek() const {
+ FailureOr<std::byte> peek() const {
if (end())
return emitError(
getLocation(),
@@ -539,14 +542,14 @@ class ParserHead {
size_t size() const { return head.size() - offset; }
- llvm::StringRef head;
+ StringRef head;
StringAttr locName;
unsigned anchorOffset{0};
unsigned offset{0};
};
template <>
-llvm::FailureOr<float> ParserHead::parseLiteral<float>() {
+FailureOr<float> ParserHead::parseLiteral<float>() {
auto bytes = consumeNBytes(4);
if (failed(bytes))
return failure();
@@ -556,7 +559,7 @@ llvm::FailureOr<float> ParserHead::parseLiteral<float>() {
}
template <>
-llvm::FailureOr<double> ParserHead::parseLiteral<double>() {
+FailureOr<double> ParserHead::parseLiteral<double>() {
auto bytes = consumeNBytes(8);
if (failed(bytes))
return failure();
@@ -566,13 +569,13 @@ llvm::FailureOr<double> ParserHead::parseLiteral<double>() {
}
template <>
-llvm::FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
+FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
char const *error = nullptr;
uint32_t res{0};
unsigned encodingSize{0};
- auto src = curHead();
- auto decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
- src.bytes_end(), &error);
+ StringRef src = curHead();
+ uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
if (error)
return emitError(getLocation(), error);
@@ -585,13 +588,13 @@ llvm::FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
}
template <>
-llvm::FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
+FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
char const *error = nullptr;
int32_t res{0};
unsigned encodingSize{0};
- auto src = curHead();
- auto decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
- src.bytes_end(), &error);
+ StringRef src = curHead();
+ int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
if (error)
return emitError(getLocation(), error);
if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
@@ -604,12 +607,12 @@ llvm::FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
}
template <>
-llvm::FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
+FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
char const *error = nullptr;
unsigned encodingSize{0};
- auto src = curHead();
- auto res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
- src.bytes_end(), &error);
+ StringRef src = curHead();
+ int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
if (error)
return emitError(getLocation(), error);
@@ -617,21 +620,21 @@ llvm::FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
return res;
}
-llvm::FailureOr<uint32_t> ParserHead::parseVectorSize() {
+FailureOr<uint32_t> ParserHead::parseVectorSize() {
return parseLiteral<uint32_t>();
}
-inline llvm::FailureOr<uint32_t> ParserHead::parseUI32() {
+inline FailureOr<uint32_t> ParserHead::parseUI32() {
return parseLiteral<uint32_t>();
}
-inline llvm::FailureOr<int64_t> ParserHead::parseI64() {
+inline FailureOr<int64_t> ParserHead::parseI64() {
return parseLiteral<int64_t>();
}
template <std::byte opCode>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
- return emitError(*currentOpLoc, "Unknown instruction opcode: ")
+ return emitError(*currentOpLoc, "unknown instruction opcode: ")
<< static_cast<int>(opCode);
}
@@ -645,7 +648,7 @@ void ValueStack::dump() const {
// end of the vector. Iterate in reverse so that the first thing we print
// is the top of the stack.
size_t stackSize = size();
- for (size_t idx = 0 ; idx < stackSize ;) {
+ for (size_t idx = 0; idx < stackSize;) {
size_t actualIdx = stackSize - 1 - idx;
llvm::dbgs() << " ";
values[actualIdx].dump();
@@ -663,18 +666,17 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
if (operandTypes.size() > values.size())
return emitError(*opLoc,
- "Stack doesn't contain enough values. Trying to get ")
+ "stack doesn't contain enough values. Trying to get ")
<< operandTypes.size() << " operands on a stack containing only "
<< values.size() << " values.";
size_t stackIdxOffset = values.size() - operandTypes.size();
- llvm::SmallVector<Value> res{};
+ SmallVector<Value> res{};
res.reserve(operandTypes.size());
for (size_t i{0}; i < operandTypes.size(); ++i) {
Value operand = values[i + stackIdxOffset];
Type stackType = operand.getType();
if (stackType != operandTypes[i])
- return emitError(*opLoc,
- "Invalid operand type on stack. Expecting ")
+ return emitError(*opLoc, "invalid operand type on stack. Expecting ")
<< operandTypes[i] << ", value on stack is of type " << stackType
<< ".";
LLVM_DEBUG(llvm::dbgs() << " POP: " << operand << "\n");
@@ -690,9 +692,9 @@ LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
LLVM_DEBUG(llvm::dbgs() << " Elements(s) to push: " << results.size()
<< "\n");
LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
- for (auto val : results) {
+ for (Value val : results) {
if (!isWasmValueType(val.getType()))
- return emitError(*opLoc, "Invalid value type on stack: ")
+ return emitError(*opLoc, "invalid value type on stack: ")
<< val.getType();
LLVM_DEBUG(llvm::dbgs() << " PUSH: " << val << "\n");
values.push_back(val);
@@ -702,8 +704,9 @@ LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
return success();
}
-template<std::byte EndParseByte>
-parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte<EndParseByte> endByte) {
+template <std::byte EndParseByte>
+parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
+ UniqueByte<EndParseByte> endByte) {
auto res = parse(builder, ByteSequence<EndParseByte>{});
if (failed(res))
return failure();
@@ -711,13 +714,13 @@ parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte<EndParseByt
}
template <std::byte... ExpressionParseEnd>
-llvm::FailureOr<ExpressionParser::ParseResultWithInfo>
+FailureOr<ExpressionParser::ParseResultWithInfo>
ExpressionParser::parse(OpBuilder &builder,
ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
- llvm::SmallVector<Value> res;
+ SmallVector<Value> res;
for (;;) {
currentOpLoc = parser.getLocation();
- auto opCode = parser.consumeByte();
+ FailureOr<std::byte> opCode = parser.consumeByte();
if (failed(opCode))
return failure();
if (isValueOneOf(*opCode, parsingEndFilters))
@@ -732,7 +735,6 @@ ExpressionParser::parse(OpBuilder &builder,
}
}
-
template <typename T>
inline Type buildLiteralType(OpBuilder &);
@@ -766,7 +768,8 @@ inline Type buildLiteralType<double>(OpBuilder &builder) {
return builder.getF64Type();
}
-template<typename ValT, typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+template <typename ValT,
+ typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
struct AttrHolder;
template <typename ValT>
@@ -779,7 +782,7 @@ struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
using type = FloatAttr;
};
-template<typename ValT>
+template <typename ValT>
using attr_holder_t = typename AttrHolder<ValT>::type;
template <typename ValT,
@@ -823,18 +826,18 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
return parseConstInst<double>(builder);
}
-
class WasmBinaryParser {
private:
struct SectionRegistry {
- using section_location_t = llvm::StringRef;
+ using section_location_t = StringRef;
- std::array<llvm::SmallVector<section_location_t>, highestWasmSectionID+1> registry;
+ std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
+ registry;
template <WasmSectionType SecType>
std::conditional_t<sectionShouldBeUnique(SecType),
std::optional<section_location_t>,
- llvm::ArrayRef<section_location_t>>
+ ArrayRef<section_location_t>>
getContentForSection() const {
constexpr auto idx = static_cast<size_t>(SecType);
if constexpr (sectionShouldBeUnique(SecType)) {
@@ -858,7 +861,7 @@ class WasmBinaryParser {
section_location_t location, Location loc) {
if (sectionShouldBeUnique(secType) && hasSection(secType))
return emitError(loc,
- "Trying to add a second instance of unique section");
+ "trying to add a second instance of unique section");
registry[static_cast<size_t>(secType)].push_back(location);
emitRemark(loc, "Adding section with section ID ")
@@ -868,26 +871,25 @@ class WasmBinaryParser {
LogicalResult populateFromBody(ParserHead ph) {
while (!ph.end()) {
- auto sectionLoc = ph.getLocation();
- auto secType = ph.parseWasmSectionType();
+ FileLineColLoc sectionLoc = ph.getLocation();
+ FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
if (failed(secType))
return failure();
- auto secSizeParsed = ph.parseLiteral<uint32_t>();
+ FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
if (failed(secSizeParsed))
return failure();
- auto secSize = *secSizeParsed;
- auto sectionContent = ph.consumeNBytes(secSize);
+ uint32_t secSize = *secSizeParsed;
+ FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
if (failed(sectionContent))
return failure();
- auto registration =
+ LogicalResult registration =
registerSection(*secType, *sectionContent, sectionLoc);
if (failed(registration))
return failure();
-
}
return success();
}
@@ -917,10 +919,10 @@ class WasmBinaryParser {
auto secSrc = secContent.value();
ParserHead ph{secSrc, sectionNameAttr};
- auto nElemsParsed = ph.parseVectorSize();
+ FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
if (failed(nElemsParsed))
return failure();
- auto nElems = *nElemsParsed;
+ uint32_t nElems = *nElemsParsed;
LLVM_DEBUG(llvm::dbgs() << "Starting to parse " << nElems
<< " items for section " << secName << ".\n");
for (size_t i = 0; i < nElems; ++i) {
@@ -929,31 +931,31 @@ class WasmBinaryParser {
}
if (!ph.end())
- return emitError(getLocation(), "Unparsed garbage at end of section ")
+ return emitError(getLocation(), "unparsed garbage at end of section ")
<< secName;
return success();
}
/// Handles the registration of a function import
- LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
- llvm::StringRef importName, TypeIdxRecord tid) {
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TypeIdxRecord tid) {
using llvm::Twine;
if (tid.id >= symbols.moduleFuncTypes.size())
- return emitError(loc, "Invalid type id: ")
+ return emitError(loc, "invalid type id: ")
<< tid.id << ". Only " << symbols.moduleFuncTypes.size()
<< " type registration.";
- auto type = symbols.moduleFuncTypes[tid.id];
- auto symbol = symbols.getNewFuncSymbolName();
- auto funcOp = builder.create<FuncImportOp>(
- loc, symbol, moduleName, importName, type);
+ FunctionType type = symbols.moduleFuncTypes[tid.id];
+ std::string symbol = symbols.getNewFuncSymbolName();
+ auto funcOp =
+ builder.create<FuncImportOp>(loc, symbol, moduleName, importName, type);
symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
return funcOp.verify();
}
/// Handles the registration of a memory import
- LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
- llvm::StringRef importName, LimitType limitType) {
- auto symbol = symbols.getNewMemorySymbolName();
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, LimitType limitType) {
+ std::string symbol = symbols.getNewMemorySymbolName();
auto memOp = builder.create<MemImportOp>(loc, symbol, moduleName,
importName, limitType);
symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
@@ -961,9 +963,9 @@ class WasmBinaryParser {
}
/// Handles the registration of a table import
- LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
- llvm::StringRef importName, TableType tableType) {
- auto symbol = symbols.getNewTableSymbolName();
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TableType tableType) {
+ std::string symbol = symbols.getNewTableSymbolName();
auto tableOp = builder.create<TableImportOp>(loc, symbol, moduleName,
importName, tableType);
symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
@@ -971,14 +973,14 @@ class WasmBinaryParser {
}
/// Handles the registration of a global variable import
- LogicalResult visitImport(Location loc, llvm::StringRef moduleName,
- llvm::StringRef importName,
- GlobalTypeRecord globalType) {
- auto symbol = symbols.getNewGlobalSymbolName();
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, GlobalTypeRecord globalType) {
+ std::string symbol = symbols.getNewGlobalSymbolName();
auto giOp =
builder.create<GlobalImportOp>(loc, symbol, moduleName, importName,
globalType.type, globalType.isMutable);
- symbols.globalSymbols.push_back({{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
+ symbols.globalSymbols.push_back(
+ {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
return giOp.verify();
}
@@ -987,63 +989,62 @@ class WasmBinaryParser {
: builder{ctx}, ctx{ctx} {
ctx->loadAllAvailableDialects();
if (sourceMgr.getNumBuffers() != 1) {
- emitError(UnknownLoc::get(ctx), "One source file should be provided");
+ emitError(UnknownLoc::get(ctx), "one source file should be provided");
return;
}
- auto sourceBufId = sourceMgr.getMainFileID();
- auto source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
+ uint32_t sourceBufId = sourceMgr.getMainFileID();
+ StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
srcName = StringAttr::get(
- ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
+ ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
auto parser = ParserHead{source, srcName};
auto const wasmHeader = StringRef{"\0asm", 4};
- auto magicLoc = parser.getLocation();
- auto magic = parser.consumeNBytes(wasmHeader.size());
+ FileLineColLoc magicLoc = parser.getLocation();
+ FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
if (failed(magic) || magic->compare(wasmHeader)) {
- emitError(magicLoc,
- "Source file does not contain valid Wasm header.");
+ emitError(magicLoc, "source file does not contain valid Wasm header.");
return;
}
auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
- auto versionLoc = parser.getLocation();
- auto version = parser.consumeNBytes(expectedVersionString.size());
+ FileLineColLoc versionLoc = parser.getLocation();
+ FailureOr<StringRef> version =
+ parser.consumeNBytes(expectedVersionString.size());
if (failed(version))
return;
if (version->compare(expectedVersionString)) {
emitError(versionLoc,
- "Unsupported Wasm version. Only version 1 is supported.");
+ "unsupported Wasm version. Only version 1 is supported.");
return;
}
- auto fillRegistry = registry.populateFromBody(parser.copy());
+ LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
if (failed(fillRegistry))
return;
mOp = builder.create<ModuleOp>(getLocation());
- builder.setInsertionPointToStart(
- &mOp.getBodyRegion().front());
- auto parsingTypes = parseSection<WasmSectionType::TYPE>();
+ builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
+ LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
if (failed(parsingTypes))
return;
- auto parsingImports = parseSection<WasmSectionType::IMPORT>();
+ LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
if (failed(parsingImports))
return;
firstInternalFuncID = symbols.funcSymbols.size();
- auto parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
+ LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
if (failed(parsingFunctions))
return;
- auto parsingTables = parseSection<WasmSectionType::TABLE>();
+ LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
if (failed(parsingTables))
return;
- auto parsingMems = parseSection<WasmSectionType::MEMORY>();
+ LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
if (failed(parsingMems))
return;
- auto parsingExports = parseSection<WasmSectionType::EXPORT>();
+ LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
if (failed(parsingExports))
return;
@@ -1068,8 +1069,9 @@ class WasmBinaryParser {
template <>
LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, size_t) {
- auto importLoc = ph.getLocation();
+WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc importLoc = ph.getLocation();
auto moduleName = ph.parseName();
if (failed(moduleName))
return failure();
@@ -1078,7 +1080,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, size
if (failed(importName))
return failure();
- auto import = ph.parseImportDesc(ctx);
+ FailureOr<ImportDesc> import = ph.parseImportDesc(ctx);
if (failed(import))
return failure();
@@ -1093,24 +1095,23 @@ template <>
LogicalResult
WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
size_t) {
- auto exportLoc = ph.getLocation();
+ FileLineColLoc exportLoc = ph.getLocation();
auto exportName = ph.parseName();
if (failed(exportName))
return failure();
- auto opcode = ph.consumeByte();
+ FailureOr<std::byte> opcode = ph.consumeByte();
if (failed(opcode))
return failure();
- auto idx = ph.parseLiteral<uint32_t>();
+ FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
if (failed(idx))
return failure();
- using SymbolRefDesc =
- std::variant<llvm::SmallVector<SymbolRefContainer>,
- llvm::SmallVector<GlobalSymbolRefContainer>,
- llvm::SmallVector<FunctionSymbolRefContainer>>;
+ using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
+ SmallVector<GlobalSymbolRefContainer>,
+ SmallVector<FunctionSymbolRefContainer>>;
SymbolRefDesc currentSymbolList;
std::string symbolType = "";
@@ -1132,7 +1133,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
currentSymbolList = symbols.globalSymbols;
break;
default:
- return emitError(exportLoc, "Invalid value for export type: ")
+ return emitError(exportLoc, "invalid value for export type: ")
<< std::to_integer<unsigned>(*opcode);
}
@@ -1142,7 +1143,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
emitError(
exportLoc,
llvm::formatv(
- "Trying to export {0} {1} which is undefined in this scope",
+ "trying to export {0} {1} which is undefined in this scope",
symbolType, *idx));
return failure();
}
@@ -1155,21 +1156,23 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
- auto symName = SymbolTable::getSymbolName(op);
+ StringAttr symName = SymbolTable::getSymbolName(op);
return SymbolTable{mOp}.rename(symName, *exportName);
}
template <>
LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, size_t) {
- auto opLocation = ph.getLocation();
- auto tableType = ph.parseTableType(ctx);
+WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<TableType> tableType = ph.parseTableType(ctx);
if (failed(tableType))
return failure();
LLVM_DEBUG(llvm::dbgs() << " Parsed table description: " << *tableType
<< '\n');
- auto symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
- auto tableOp = builder.create<TableOp>(opLocation, symbol.strref(), *tableType);
+ StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
+ auto tableOp =
+ builder.create<TableOp>(opLocation, symbol.strref(), *tableType);
symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
return success();
}
@@ -1178,17 +1181,17 @@ template <>
LogicalResult
WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
size_t) {
- auto opLoc = ph.getLocation();
+ FileLineColLoc opLoc = ph.getLocation();
auto typeIdxParsed = ph.parseLiteral<uint32_t>();
if (failed(typeIdxParsed))
return failure();
- auto typeIdx = *typeIdxParsed;
+ uint32_t typeIdx = *typeIdxParsed;
if (typeIdx >= symbols.moduleFuncTypes.size())
- return emitError(getLocation(), "Invalid type index: ") << typeIdx;
- auto symbol = symbols.getNewFuncSymbolName();
+ return emitError(getLocation(), "invalid type index: ") << typeIdx;
+ std::string symbol = symbols.getNewFuncSymbolName();
auto funcOp =
builder.create<FuncOp>(opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
- auto *block = funcOp.addEntryBlock();
+ Block *block = funcOp.addEntryBlock();
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToEnd(block);
builder.create<ReturnOp>(opLoc);
@@ -1203,7 +1206,7 @@ template <>
LogicalResult
WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
size_t) {
- auto funcType = ph.parseFunctionType(ctx);
+ FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
if (failed(funcType))
return failure();
LLVM_DEBUG(llvm::dbgs() << "Parsed function type " << *funcType << '\n');
@@ -1213,30 +1216,29 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
template <>
LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph, size_t) {
- auto opLocation = ph.getLocation();
- auto memory = ph.parseLimit(ctx);
+WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<LimitType> memory = ph.parseLimit(ctx);
if (failed(memory))
return failure();
LLVM_DEBUG(llvm::dbgs() << " Registering memory " << *memory << '\n');
- auto symbol = symbols.getNewMemorySymbolName();
+ std::string symbol = symbols.getNewMemorySymbolName();
auto memOp = builder.create<MemOp>(opLocation, symbol, *memory);
symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
return success();
}
} // namespace
-namespace mlir {
-namespace wasm {
+namespace mlir::wasm {
OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
MLIRContext *context) {
WasmBinaryParser wBN{source, context};
- auto mOp = wBN.getModule();
+ ModuleOp mOp = wBN.getModule();
if (mOp)
return {mOp};
return {nullptr};
}
-} // namespace wasm
-} // namespace mlir
+} // namespace mlir::wasm
diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
index 9c0f7702a96aa..03b97846d45d3 100644
--- a/mlir/lib/Target/Wasm/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
@@ -11,18 +11,18 @@
#include "mlir/Target/Wasm/WasmImporter.h"
#include "mlir/Tools/mlir-translate/Translation.h"
-
using namespace mlir;
namespace mlir {
void registerFromWasmTranslation() {
TranslateToMLIRRegistration registration{
- "import-wasm", "Translate WASM to MLIR",
- [](llvm::SourceMgr &sourceMgr, MLIRContext* context) -> OwningOpRef<Operation*> {
- return wasm::importWebAssemblyToModule(sourceMgr, context);
- }, [](DialectRegistry& registry) {
- registry.insert<wasmssa::WasmSSADialect>();
- }
- };
+ "import-wasm", "Translate WASM to MLIR",
+ [](llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) -> OwningOpRef<Operation *> {
+ return wasm::importWebAssemblyToModule(sourceMgr, context);
+ },
+ [](DialectRegistry &registry) {
+ registry.insert<wasmssa::WasmSSADialect>();
+ }};
}
} // namespace mlir
diff --git a/mlir/test/Target/Wasm/bad_wasm_version.yaml b/mlir/test/Target/Wasm/bad_wasm_version.yaml
index 4fed1d5a3af3c..f834afbef679d 100644
--- a/mlir/test/Target/Wasm/bad_wasm_version.yaml
+++ b/mlir/test/Target/Wasm/bad_wasm_version.yaml
@@ -1,6 +1,6 @@
# RUN: yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s
-# CHECK: Unsupported Wasm version
+# CHECK: unsupported Wasm version
--- !WASM
FileHeader:
diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
index ffb26f563141a..5adbd861bad36 100644
--- a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
+++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
@@ -2,7 +2,7 @@
# FIXME: The error code here should be nonzero.
-# CHECK: Trying to export function 42 which is undefined in this scope
+# CHECK: trying to export function 42 which is undefined in this scope
--- !WASM
FileHeader:
diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
index 961e9cc6e8029..b01a623c41209 100644
--- a/mlir/test/Target/Wasm/invalid_function_type_index.yaml
+++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
@@ -1,5 +1,5 @@
# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
-# CHECK: error: Invalid type index: 2
+# CHECK: error: invalid type index: 2
# FIXME: mlir-translate should not return 0 here.
diff --git a/mlir/test/Target/Wasm/missing_header.yaml b/mlir/test/Target/Wasm/missing_header.yaml
index 5610f9c5c6e33..a9f812e0a77f8 100644
--- a/mlir/test/Target/Wasm/missing_header.yaml
+++ b/mlir/test/Target/Wasm/missing_header.yaml
@@ -1,6 +1,6 @@
# RUN: not yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s
-# CHECK: Source file does not contain valid Wasm header
+# CHECK: source file does not contain valid Wasm header
--- !WASM
Sections:
>From 6f6c52785bedfaed8ef46426c9cb1fba7e566f9f Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Thu, 7 Aug 2025 09:25:34 +0000
Subject: [PATCH 07/14] [MLIR][WASM] Make WasmSSA Importer signal program
failure on error
Also contains non-functional changes to use the `LDBG` macro.
---
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 61 +++++++++++--------
.../Wasm/function_export_out_of_scope.yaml | 4 +-
.../Wasm/invalid_function_type_index.yaml | 4 +-
3 files changed, 38 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index dd8b86670c31a..ed3b43a4408ac 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -15,6 +15,7 @@
#include "mlir/Target/Wasm/WasmImporter.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LEB128.h"
@@ -322,17 +323,16 @@ class ParserHead {
}
FailureOr<StringRef> consumeNBytes(size_t nBytes) {
- LLVM_DEBUG(llvm::dbgs() << "Consume " << nBytes << " bytes\n");
- LLVM_DEBUG(llvm::dbgs() << " Bytes remaining: " << size() << "\n");
- LLVM_DEBUG(llvm::dbgs() << " Current offset: " << offset << "\n");
+ LDBG() << "Consume " << nBytes << " bytes";
+ LDBG() << " Bytes remaining: " << size();
+ LDBG() << " Current offset: " << offset;
if (nBytes > size())
return emitError(getLocation(), "trying to extract ")
<< nBytes << "bytes when only " << size() << "are avilables";
StringRef res = head.slice(offset, offset + nBytes);
offset += nBytes;
- LLVM_DEBUG(llvm::dbgs()
- << " Updated offset (+" << nBytes << "): " << offset << "\n");
+ LDBG() << " Updated offset (+" << nBytes << "): " << offset;
return res;
}
@@ -660,10 +660,9 @@ void ValueStack::dump() const {
#endif
parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
- LLVM_DEBUG(llvm::dbgs() << "Popping from ValueStack\n");
- LLVM_DEBUG(llvm::dbgs() << " Elements(s) to pop: " << operandTypes.size()
- << "\n");
- LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
+ LDBG() << "Popping from ValueStack\n"
+ << " Elements(s) to pop: " << operandTypes.size() << "\n"
+ << " Current stack size: " << values.size();
if (operandTypes.size() > values.size())
return emitError(*opLoc,
"stack doesn't contain enough values. Trying to get ")
@@ -679,28 +678,27 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
return emitError(*opLoc, "invalid operand type on stack. Expecting ")
<< operandTypes[i] << ", value on stack is of type " << stackType
<< ".";
- LLVM_DEBUG(llvm::dbgs() << " POP: " << operand << "\n");
+ LDBG() << " POP: " << operand;
res.push_back(operand);
}
values.resize(values.size() - operandTypes.size());
- LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n");
+ LDBG() << " Updated stack size: " << values.size();
return res;
}
LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
- LLVM_DEBUG(llvm::dbgs() << "Pushing to ValueStack\n");
- LLVM_DEBUG(llvm::dbgs() << " Elements(s) to push: " << results.size()
- << "\n");
- LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
+ LDBG() << "Pushing to ValueStack\n"
+ << " Elements(s) to push: " << results.size() << "\n"
+ << " Current stack size: " << values.size();
for (Value val : results) {
if (!isWasmValueType(val.getType()))
return emitError(*opLoc, "invalid value type on stack: ")
<< val.getType();
- LLVM_DEBUG(llvm::dbgs() << " PUSH: " << val << "\n");
+ LDBG() << " PUSH: " << val;
values.push_back(val);
}
- LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n");
+ LDBG() << " Updated stack size: " << values.size();
return success();
}
@@ -913,7 +911,7 @@ class WasmBinaryParser {
};
auto secContent = registry.getContentForSection<section>();
if (!secContent) {
- LLVM_DEBUG(llvm::dbgs() << secName << " section is not present in file.");
+ LDBG() << secName << " section is not present in file.";
return success();
}
@@ -923,8 +921,8 @@ class WasmBinaryParser {
if (failed(nElemsParsed))
return failure();
uint32_t nElems = *nElemsParsed;
- LLVM_DEBUG(llvm::dbgs() << "Starting to parse " << nElems
- << " items for section " << secName << ".\n");
+ LDBG() << "Starting to parse " << nElems << " items for section "
+ << secName;
for (size_t i = 0; i < nElems; ++i) {
if (failed(parseSectionItem<section>(ph, i)))
return failure();
@@ -984,9 +982,18 @@ class WasmBinaryParser {
return giOp.verify();
}
+ // Detect occurrence of errors
+ LogicalResult peekDiag(Diagnostic &diag) {
+ if (diag.getSeverity() == DiagnosticSeverity::Error)
+ isValid = false;
+ return failure();
+ }
+
public:
WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
: builder{ctx}, ctx{ctx} {
+ ctx->getDiagEngine().registerHandler(
+ [this](Diagnostic &diag) { return peekDiag(diag); });
ctx->loadAllAvailableDialects();
if (sourceMgr.getNumBuffers() != 1) {
emitError(UnknownLoc::get(ctx), "one source file should be provided");
@@ -1055,7 +1062,11 @@ class WasmBinaryParser {
numTableSectionItems = symbols.tableSymbols.size();
}
- ModuleOp getModule() { return mOp; }
+ ModuleOp getModule() {
+ if (isValid)
+ return mOp;
+ return ModuleOp{};
+ }
private:
mlir::StringAttr srcName;
@@ -1065,6 +1076,7 @@ class WasmBinaryParser {
ModuleOp mOp;
SectionRegistry registry;
size_t firstInternalFuncID{0};
+ bool isValid{true};
};
template <>
@@ -1168,8 +1180,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
FailureOr<TableType> tableType = ph.parseTableType(ctx);
if (failed(tableType))
return failure();
- LLVM_DEBUG(llvm::dbgs() << " Parsed table description: " << *tableType
- << '\n');
+ LDBG() << " Parsed table description: " << *tableType;
StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
auto tableOp =
builder.create<TableOp>(opLocation, symbol.strref(), *tableType);
@@ -1209,7 +1220,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
if (failed(funcType))
return failure();
- LLVM_DEBUG(llvm::dbgs() << "Parsed function type " << *funcType << '\n');
+ LDBG() << "Parsed function type " << *funcType;
symbols.moduleFuncTypes.push_back(*funcType);
return success();
}
@@ -1223,7 +1234,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
if (failed(memory))
return failure();
- LLVM_DEBUG(llvm::dbgs() << " Registering memory " << *memory << '\n');
+ LDBG() << " Registering memory " << *memory;
std::string symbol = symbols.getNewMemorySymbolName();
auto memOp = builder.create<MemOp>(opLocation, symbol, *memory);
symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
index 5adbd861bad36..b08c2c87abdb3 100644
--- a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
+++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
@@ -1,6 +1,4 @@
-# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
-
-# FIXME: The error code here should be nonzero.
+# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
# CHECK: trying to export function 42 which is undefined in this scope
diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
index b01a623c41209..2d2954aa32dda 100644
--- a/mlir/test/Target/Wasm/invalid_function_type_index.yaml
+++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
@@ -1,8 +1,6 @@
-# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
+# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
# CHECK: error: invalid type index: 2
-# FIXME: mlir-translate should not return 0 here.
-
--- !WASM
FileHeader:
Version: 0x00000001
>From b58ddbe8f8292ff26445fdb2a683b912c7719828 Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Fri, 8 Aug 2025 02:20:42 +0000
Subject: [PATCH 08/14] [MLIR] Adding yaml2obj as dependency to mlir unit tests
This is due to the tests in the wasm importer requiring it.
---
mlir/test/CMakeLists.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 89568e7766ae5..c21e3610b9066 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -123,7 +123,7 @@ set(MLIR_TEST_DEPENDS
tblgen-to-irdl
)
if(NOT MLIR_STANDALONE_BUILD)
- list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file)
+ list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file yaml2obj)
endif()
set(MLIR_TEST_DEPENDS ${MLIR_TEST_DEPENDS}
>From ee8d3889ba59bd94ea202f3b6e4962df72ead4bd Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Tue, 12 Aug 2025 10:24:20 +0900
Subject: [PATCH 09/14] [WASM][MLIR] Importer tests use the new custom format
---
mlir/test/Target/Wasm/memory_min_eq_max.mlir | 2 +-
mlir/test/Target/Wasm/memory_min_max.mlir | 2 +-
mlir/test/Target/Wasm/memory_min_no_max.mlir | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
index 088e28685d09a..2ba5ab50d51fa 100644
--- a/mlir/test/Target/Wasm/memory_min_eq_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
@@ -4,4 +4,4 @@
(module (memory 0 0))
*/
-// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa<limit[0: 0]>, sym_name = "mem_0", sym_visibility = "nested"}> : () -> ()
+// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 0]>
diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir
index 16d3468279d42..ebf64189189f8 100644
--- a/mlir/test/Target/Wasm/memory_min_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_max.mlir
@@ -4,4 +4,4 @@
(module (memory 0 65536))
*/
-// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa<limit[0: 65536]>, sym_name = "mem_0", sym_visibility = "nested"}> : () -> ()
+// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 65536]>
diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir
index f71cb1098be18..8d8878618bcc0 100644
--- a/mlir/test/Target/Wasm/memory_min_no_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir
@@ -4,4 +4,4 @@
(module (memory 1))
*/
-// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa<limit[1:]>, sym_name = "mem_0", sym_visibility = "nested"}> : () -> ()
+// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[1:]>
>From 6e952fc74b578e07b334ed4d4e995cfaed4e7186 Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Tue, 12 Aug 2025 13:48:20 +0900
Subject: [PATCH 10/14] [MLIR][WASM] Implement review remarks
---
.../include/mlir/Target/Wasm/WasmBinaryEncoding.h | 15 ++++++++-------
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 8 +++++++-
2 files changed, 15 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
index a5b124eecbe67..3280432b5f038 100644
--- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -4,19 +4,20 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Define encodings for WebAssembly instructions, types, etc from the
+// Define various flags used to encode instructions, types, etc. in
// WebAssembly binary format.
//
-// Each encoding is defined in the WebAssembly binary specification.
+// These encodings are defined in the WebAssembly binary format specification.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TARGET_WASMBINARYENCODING
#define MLIR_TARGET_WASMBINARYENCODING
#include <cstddef>
+
namespace mlir {
struct WasmBinaryEncoding {
- /// Byte encodings for WASM instructions.
+ /// Byte encodings for Wasm instructions.
struct OpCode {
// Locals, globals, constants.
static constexpr std::byte constI32{0x41};
@@ -25,7 +26,7 @@ struct WasmBinaryEncoding {
static constexpr std::byte constFP64{0x44};
};
- /// Byte encodings of types in WASM binaries
+ /// Byte encodings of types in Wasm binaries
struct Type {
static constexpr std::byte emptyBlockType{0x40};
static constexpr std::byte funcType{0x60};
@@ -38,7 +39,7 @@ struct WasmBinaryEncoding {
static constexpr std::byte i32{0x7F};
};
- /// Byte encodings of WASM imports.
+ /// Byte encodings of Wasm imports.
struct Import {
static constexpr std::byte typeID{0x00};
static constexpr std::byte tableType{0x01};
@@ -46,7 +47,7 @@ struct WasmBinaryEncoding {
static constexpr std::byte globalType{0x03};
};
- /// Byte encodings for WASM limits.
+ /// Byte encodings for Wasm limits.
struct LimitHeader {
static constexpr std::byte lowLimitOnly{0x00};
static constexpr std::byte bothLimits{0x01};
@@ -58,7 +59,7 @@ struct WasmBinaryEncoding {
static constexpr std::byte isMutable{0x01};
};
- /// Byte encodings describing WASM exports.
+ /// Byte encodings describing Wasm exports.
struct Export {
static constexpr std::byte function{0x00};
static constexpr std::byte table{0x01};
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index ed3b43a4408ac..a958a71e19018 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -5,6 +5,11 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file implements the WebAssembly importer.
+//
+//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -19,6 +24,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LEB128.h"
+#include <climits>
#include <cstdint>
#include <variant>
@@ -648,7 +654,7 @@ void ValueStack::dump() const {
// end of the vector. Iterate in reverse so that the first thing we print
// is the top of the stack.
size_t stackSize = size();
- for (size_t idx = 0; idx < stackSize;) {
+ for (size_t idx = 0; idx < stackSize; idx++) {
size_t actualIdx = stackSize - 1 - idx;
llvm::dbgs() << " ";
values[actualIdx].dump();
>From 5ebf0e1731b7bda60494d277b9ebdffb6f8d1d1b Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Tue, 12 Aug 2025 17:15:48 +0900
Subject: [PATCH 11/14] [MLIR][WASM] Fix template to please MSVC
---
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index a958a71e19018..4dc1d425188ad 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -106,7 +106,7 @@ template <std::byte Byte>
struct UniqueByte : ByteSequence<Byte> {};
template <typename T, T... Values>
-constexpr ByteSequence<std::byte{Values}...>
+constexpr ByteSequence<std::byte(Values)...>
byteSeqFromIntSeq(std::integer_sequence<T, Values...>) {
return {};
}
>From 819e2971a9701cb79b5eb3ce4d7f95689406e729 Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Wed, 13 Aug 2025 13:41:10 +0900
Subject: [PATCH 12/14] [MLIR][WASM] NFC remove deadcode in WasmImporter
---
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 17 +----------------
1 file changed, 1 insertion(+), 16 deletions(-)
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 4dc1d425188ad..d0fa70c26faee 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -1,4 +1,4 @@
-//===- TranslateFromWasm.cpp - Translating to C++ calls -------------------===//
+//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -95,25 +95,10 @@ constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
template <std::byte... Bytes>
struct ByteSequence {};
-template <std::byte... Bytes1, std::byte... Bytes2>
-constexpr ByteSequence<Bytes1..., Bytes2...>
-operator+(ByteSequence<Bytes1...>, ByteSequence<Bytes2...>) {
- return {};
-}
-
/// Template class for representing a byte sequence of only one byte
template <std::byte Byte>
struct UniqueByte : ByteSequence<Byte> {};
-template <typename T, T... Values>
-constexpr ByteSequence<std::byte(Values)...>
-byteSeqFromIntSeq(std::integer_sequence<T, Values...>) {
- return {};
-}
-
-constexpr auto allOpCodes =
- byteSeqFromIntSeq(std::make_integer_sequence<int, 256>());
-
constexpr ByteSequence<
WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64,
WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64,
>From bff79f3af339387d8a8fa71869cf9b00aa859ecb Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Thu, 14 Aug 2025 10:06:14 +0900
Subject: [PATCH 13/14] [MLIR][WASM] NFC: fix typo and documentation
---
mlir/include/mlir/Target/Wasm/WasmImporter.h | 9 +++------
mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 2 +-
2 files changed, 4 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h
index 5cc42a1f32fa4..9b52f13e042df 100644
--- a/mlir/include/mlir/Target/Wasm/WasmImporter.h
+++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h
@@ -21,12 +21,9 @@
namespace mlir::wasm {
-/// Translates the given operation to C++ code. The operation or operations in
-/// the region of 'op' need almost all be in EmitC dialect. The parameter
-/// 'declareVariablesAtTop' enforces that all variables for op results and block
-/// arguments are declared at the beginning of the function.
-/// If parameter 'fileId' is non-empty, then body of `emitc.file` ops
-/// with matching id are emitted.
+/// If `source` contains a valid Wasm binary file, this function returns a
+/// a ModuleOp containing the representation of trhe Wasm module encoded in
+/// the source file in the `wasmssa` dialect.
OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
MLIRContext *context);
} // namespace mlir::wasm
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index d0fa70c26faee..c23a2915ef18e 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -319,7 +319,7 @@ class ParserHead {
LDBG() << " Current offset: " << offset;
if (nBytes > size())
return emitError(getLocation(), "trying to extract ")
- << nBytes << "bytes when only " << size() << "are avilables";
+ << nBytes << "bytes when only " << size() << "are available";
StringRef res = head.slice(offset, offset + nBytes);
offset += nBytes;
>From 3e2caec7eea3464921ef391953697b883eb31ddd Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Thu, 14 Aug 2025 14:25:15 +0900
Subject: [PATCH 14/14] [MLIR][WASM] NFC: fix typo
---
mlir/include/mlir/Target/Wasm/WasmImporter.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h
index 9b52f13e042df..3f05bbecefc8a 100644
--- a/mlir/include/mlir/Target/Wasm/WasmImporter.h
+++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h
@@ -22,7 +22,7 @@
namespace mlir::wasm {
/// If `source` contains a valid Wasm binary file, this function returns a
-/// a ModuleOp containing the representation of trhe Wasm module encoded in
+/// a ModuleOp containing the representation of the Wasm module encoded in
/// the source file in the `wasmssa` dialect.
OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
MLIRContext *context);
More information about the Mlir-commits
mailing list