[Mlir-commits] [mlir] [MLIR][WASM] - Introduce an importer for Wasm binaries (PR #152131)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 5 05:24:32 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h mlir/include/mlir/Target/Wasm/WasmImporter.h mlir/lib/Target/Wasm/TranslateFromWasm.cpp mlir/lib/Target/Wasm/TranslateRegistration.cpp mlir/include/mlir/InitAllTranslations.h
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h
index 188c962db..5cc42a1f3 100644
--- a/mlir/include/mlir/Target/Wasm/WasmImporter.h
+++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h
@@ -27,7 +27,8 @@ namespace mlir::wasm {
 /// arguments are declared at the beginning of the function.
 /// If parameter 'fileId' is non-empty, then body of `emitc.file` ops
 /// with matching id are emitted.
-OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext* context);
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+                                                MLIRContext *context);
 } // namespace mlir::wasm
 
 #endif // MLIR_TARGET_WASM_WASMIMPORTER_H
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 7fe59a49b..dd8b86670 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -29,7 +29,8 @@ STATISTIC(numGlobalSectionItems, "Parsed globals");
 STATISTIC(numMemorySectionItems, "Parsed memories");
 STATISTIC(numTableSectionItems, "Parsed tables");
 
-static_assert(CHAR_BIT == 8, "This code expects std::byte to be exactly 8 bits");
+static_assert(CHAR_BIT == 8,
+              "This code expects std::byte to be exactly 8 bits");
 
 using namespace mlir;
 using namespace mlir::wasm;
@@ -54,7 +55,7 @@ enum struct WasmSectionType : section_id_t {
 };
 
 constexpr section_id_t highestWasmSectionID{
-  static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
+    static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
 
 #define APPLY_WASM_SEC_TRANSFORM                                               \
   WASM_SEC_TRANSFORM(CUSTOM)                                                   \
@@ -85,7 +86,7 @@ constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
 }
 
 template <std::byte... Bytes>
-struct ByteSequence{};
+struct ByteSequence {};
 
 template <std::byte... Bytes1, std::byte... Bytes2>
 constexpr ByteSequence<Bytes1..., Bytes2...>
@@ -94,7 +95,7 @@ operator+(ByteSequence<Bytes1...>, ByteSequence<Bytes2...>) {
 }
 
 /// Template class for representing a byte sequence of only one byte
-template<std::byte Byte>
+template <std::byte Byte>
 struct UniqueByte : ByteSequence<Byte> {};
 
 template <typename T, T... Values>
@@ -112,12 +113,13 @@ constexpr ByteSequence<
     WasmBinaryEncoding::Type::v128>
     valueTypesEncodings{};
 
-template<std::byte... allowedFlags>
-constexpr bool isValueOneOf(std::byte value, ByteSequence<allowedFlags...> = {}) {
-  return  ((value == allowedFlags) | ... | false);
+template <std::byte... allowedFlags>
+constexpr bool isValueOneOf(std::byte value,
+                            ByteSequence<allowedFlags...> = {}) {
+  return ((value == allowedFlags) | ... | false);
 }
 
-template<std::byte... flags>
+template <std::byte... flags>
 constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
   return !isValueOneOf<flags...>(value);
 }
@@ -143,7 +145,8 @@ struct FunctionSymbolRefContainer : SymbolRefContainer {
   FunctionType functionType;
 };
 
-using ImportDesc = std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+using ImportDesc =
+    std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
 
 using parsed_inst_t = FailureOr<SmallVector<Value>>;
 
@@ -190,6 +193,7 @@ private:
     size_t stackIdx;
     LabelLevelOpInterface levelOp;
   };
+
 public:
   bool empty() const { return values.empty(); }
 
@@ -204,7 +208,7 @@ public:
   ///   if an error occurs.
   /// @return Failure or the vector of popped values.
   FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
-                                                        Location *opLoc);
+                                            Location *opLoc);
 
   /// Push the results of an operation to the stack so they can be used in a
   /// following operation.
@@ -214,7 +218,6 @@ public:
   ///   if an error occurs.
   LogicalResult pushResults(ValueRange results, Location *opLoc);
 
-
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// A simple dump function for debugging.
   /// Writes output to llvm::dbgs().
@@ -243,7 +246,6 @@ private:
   parseConstInst(OpBuilder &builder,
                  std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
 
-
   /// This function generates a dispatch tree to associate an opcode with a
   /// parser. Parsers are registered by specialising the
   /// `parseSpecificInstruction` function for the op code to handle.
@@ -266,8 +268,9 @@ private:
       constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
       constexpr size_t nextPatternBitSize = patternBitSize + 1;
       if ((opCode & bitSelect) != std::byte{0})
-        return dispatchToInstParser < nextPatternBitSize,
-               nextHighBitPatternStem | std::byte{1} > (opCode, builder);
+        return dispatchToInstParser<nextPatternBitSize,
+                                    nextHighBitPatternStem | std::byte{1}>(
+            opCode, builder);
       return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
           opCode, builder);
     } else {
@@ -281,35 +284,35 @@ private:
   };
 
 public:
-  template<std::byte ParseEndByte = WasmBinaryEncoding::endByte>
-  parsed_inst_t parse(OpBuilder &builder,
-                      UniqueByte<ParseEndByte> = {});
+  template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
+  parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
 
   template <std::byte... ExpressionParseEnd>
   FailureOr<ParseResultWithInfo>
   parse(OpBuilder &builder,
         ByteSequence<ExpressionParseEnd...> parsingEndFilters);
 
-  FailureOr<SmallVector<Value>>
-  popOperands(TypeRange operandTypes) {
+  FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
     return valueStack.popOperands(operandTypes, &currentOpLoc.value());
   }
 
   LogicalResult pushResults(ValueRange results) {
     return valueStack.pushResults(results, &currentOpLoc.value());
   }
+
 private:
   std::optional<Location> currentOpLoc;
   ParserHead &parser;
   WasmModuleSymbolTables const &symbols;
   locals_t locals;
   ValueStack valueStack;
-  };
+};
 
 class ParserHead {
 public:
   ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
   ParserHead(ParserHead &&) = default;
+
 private:
   ParserHead(ParserHead const &other) = default;
 
@@ -458,8 +461,7 @@ public:
       return failure();
     if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
       return emitError(typeLoc, "invalid function type header byte. Expecting ")
-             << std::to_integer<unsigned>(
-                    WasmBinaryEncoding::Type::funcType)
+             << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
              << " got " << std::to_integer<unsigned>(*funcTypeHeader);
     FailureOr<TupleType> inputTypes = parseResultType(ctx);
     if (failed(inputTypes))
@@ -525,9 +527,7 @@ public:
 
   bool end() const { return curHead().empty(); }
 
-  ParserHead copy() const {
-    return *this;
-  }
+  ParserHead copy() const { return *this; }
 
 private:
   StringRef curHead() const { return head.drop_front(offset); }
@@ -575,7 +575,7 @@ FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
   unsigned encodingSize{0};
   StringRef src = curHead();
   uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
-                                     src.bytes_end(), &error);
+                                         src.bytes_end(), &error);
   if (error)
     return emitError(getLocation(), error);
 
@@ -594,7 +594,7 @@ FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
   unsigned encodingSize{0};
   StringRef src = curHead();
   int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
-                                     src.bytes_end(), &error);
+                                        src.bytes_end(), &error);
   if (error)
     return emitError(getLocation(), error);
   if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
@@ -612,7 +612,7 @@ FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
   unsigned encodingSize{0};
   StringRef src = curHead();
   int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
-                                 src.bytes_end(), &error);
+                                    src.bytes_end(), &error);
   if (error)
     return emitError(getLocation(), error);
 
@@ -648,7 +648,7 @@ void ValueStack::dump() const {
   // end of the vector. Iterate in reverse so that the first thing we print
   // is the top of the stack.
   size_t stackSize = size();
-  for (size_t idx = 0 ; idx < stackSize ;) {
+  for (size_t idx = 0; idx < stackSize;) {
     size_t actualIdx = stackSize - 1 - idx;
     llvm::dbgs() << "  ";
     values[actualIdx].dump();
@@ -676,8 +676,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
     Value operand = values[i + stackIdxOffset];
     Type stackType = operand.getType();
     if (stackType != operandTypes[i])
-      return emitError(*opLoc,
-                       "invalid operand type on stack. Expecting ")
+      return emitError(*opLoc, "invalid operand type on stack. Expecting ")
              << operandTypes[i] << ", value on stack is of type " << stackType
              << ".";
     LLVM_DEBUG(llvm::dbgs() << "    POP: " << operand << "\n");
@@ -705,8 +704,9 @@ LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
   return success();
 }
 
-template<std::byte EndParseByte>
-parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte<EndParseByte> endByte) {
+template <std::byte EndParseByte>
+parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
+                                      UniqueByte<EndParseByte> endByte) {
   auto res = parse(builder, ByteSequence<EndParseByte>{});
   if (failed(res))
     return failure();
@@ -735,7 +735,6 @@ ExpressionParser::parse(OpBuilder &builder,
   }
 }
 
-
 template <typename T>
 inline Type buildLiteralType(OpBuilder &);
 
@@ -769,7 +768,8 @@ inline Type buildLiteralType<double>(OpBuilder &builder) {
   return builder.getF64Type();
 }
 
-template<typename ValT, typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+template <typename ValT,
+          typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
 struct AttrHolder;
 
 template <typename ValT>
@@ -782,7 +782,7 @@ struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
   using type = FloatAttr;
 };
 
-template<typename ValT>
+template <typename ValT>
 using attr_holder_t = typename AttrHolder<ValT>::type;
 
 template <typename ValT,
@@ -826,13 +826,13 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
   return parseConstInst<double>(builder);
 }
 
-
 class WasmBinaryParser {
 private:
   struct SectionRegistry {
     using section_location_t = StringRef;
 
-    std::array<SmallVector<section_location_t>, highestWasmSectionID+1> registry;
+    std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
+        registry;
 
     template <WasmSectionType SecType>
     std::conditional_t<sectionShouldBeUnique(SecType),
@@ -890,7 +890,6 @@ private:
 
         if (failed(registration))
           return failure();
-
       }
       return success();
     }
@@ -947,8 +946,8 @@ private:
              << " type registration.";
     FunctionType type = symbols.moduleFuncTypes[tid.id];
     std::string symbol = symbols.getNewFuncSymbolName();
-    auto funcOp = builder.create<FuncImportOp>(
-        loc, symbol, moduleName, importName, type);
+    auto funcOp =
+        builder.create<FuncImportOp>(loc, symbol, moduleName, importName, type);
     symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
     return funcOp.verify();
   }
@@ -975,13 +974,13 @@ private:
 
   /// Handles the registration of a global variable import
   LogicalResult visitImport(Location loc, StringRef moduleName,
-                            StringRef importName,
-                            GlobalTypeRecord globalType) {
+                            StringRef importName, GlobalTypeRecord globalType) {
     std::string symbol = symbols.getNewGlobalSymbolName();
     auto giOp =
         builder.create<GlobalImportOp>(loc, symbol, moduleName, importName,
                                        globalType.type, globalType.isMutable);
-    symbols.globalSymbols.push_back({{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
+    symbols.globalSymbols.push_back(
+        {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
     return giOp.verify();
   }
 
@@ -996,20 +995,20 @@ public:
     uint32_t sourceBufId = sourceMgr.getMainFileID();
     StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
     srcName = StringAttr::get(
-      ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
+        ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
 
     auto parser = ParserHead{source, srcName};
     auto const wasmHeader = StringRef{"\0asm", 4};
     FileLineColLoc magicLoc = parser.getLocation();
     FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
     if (failed(magic) || magic->compare(wasmHeader)) {
-      emitError(magicLoc,
-                "source file does not contain valid Wasm header.");
+      emitError(magicLoc, "source file does not contain valid Wasm header.");
       return;
     }
     auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
     FileLineColLoc versionLoc = parser.getLocation();
-    FailureOr<StringRef> version = parser.consumeNBytes(expectedVersionString.size());
+    FailureOr<StringRef> version =
+        parser.consumeNBytes(expectedVersionString.size());
     if (failed(version))
       return;
     if (version->compare(expectedVersionString)) {
@@ -1022,8 +1021,7 @@ public:
       return;
 
     mOp = builder.create<ModuleOp>(getLocation());
-    builder.setInsertionPointToStart(
-        &mOp.getBodyRegion().front());
+    builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
     LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
     if (failed(parsingTypes))
       return;
@@ -1071,7 +1069,8 @@ private:
 
 template <>
 LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, size_t) {
+WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
+                                                            size_t) {
   FileLineColLoc importLoc = ph.getLocation();
   auto moduleName = ph.parseName();
   if (failed(moduleName))
@@ -1110,10 +1109,9 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
   if (failed(idx))
     return failure();
 
-  using SymbolRefDesc =
-      std::variant<SmallVector<SymbolRefContainer>,
-                   SmallVector<GlobalSymbolRefContainer>,
-                   SmallVector<FunctionSymbolRefContainer>>;
+  using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
+                                     SmallVector<GlobalSymbolRefContainer>,
+                                     SmallVector<FunctionSymbolRefContainer>>;
 
   SymbolRefDesc currentSymbolList;
   std::string symbolType = "";
@@ -1164,7 +1162,8 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
 
 template <>
 LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, size_t) {
+WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
+                                                           size_t) {
   FileLineColLoc opLocation = ph.getLocation();
   FailureOr<TableType> tableType = ph.parseTableType(ctx);
   if (failed(tableType))
@@ -1172,7 +1171,8 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, size_
   LLVM_DEBUG(llvm::dbgs() << "  Parsed table description: " << *tableType
                           << '\n');
   StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
-  auto tableOp = builder.create<TableOp>(opLocation, symbol.strref(), *tableType);
+  auto tableOp =
+      builder.create<TableOp>(opLocation, symbol.strref(), *tableType);
   symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
   return success();
 }
@@ -1216,7 +1216,8 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
 
 template <>
 LogicalResult
-WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph, size_t) {
+WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
+                                                            size_t) {
   FileLineColLoc opLocation = ph.getLocation();
   FailureOr<LimitType> memory = ph.parseLimit(ctx);
   if (failed(memory))
diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
index 9c0f7702a..03b97846d 100644
--- a/mlir/lib/Target/Wasm/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
@@ -11,18 +11,18 @@
 #include "mlir/Target/Wasm/WasmImporter.h"
 #include "mlir/Tools/mlir-translate/Translation.h"
 
-
 using namespace mlir;
 
 namespace mlir {
 void registerFromWasmTranslation() {
   TranslateToMLIRRegistration registration{
-    "import-wasm", "Translate WASM to MLIR",
-    [](llvm::SourceMgr &sourceMgr, MLIRContext* context) -> OwningOpRef<Operation*> {
-      return wasm::importWebAssemblyToModule(sourceMgr, context);
-    }, [](DialectRegistry& registry) {
-      registry.insert<wasmssa::WasmSSADialect>();
-    }
-  };
+      "import-wasm", "Translate WASM to MLIR",
+      [](llvm::SourceMgr &sourceMgr,
+         MLIRContext *context) -> OwningOpRef<Operation *> {
+        return wasm::importWebAssemblyToModule(sourceMgr, context);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<wasmssa::WasmSSADialect>();
+      }};
 }
 } // namespace mlir

``````````

</details>


https://github.com/llvm/llvm-project/pull/152131


More information about the Mlir-commits mailing list