[Mlir-commits] [mlir] [MLIR][WASM] Introduce the RaiseWasmMLIRPass to lower WasmSSA MLIR to core dialects (PR #164562)

Ferdinand Lemaire llvmlistbot at llvm.org
Wed Oct 22 22:00:50 PDT 2025


================
@@ -0,0 +1,466 @@
+//===- RaiseWasmMLIR.cpp - Convert Wasm to less abstract dialects ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of WasmSSA operations to core dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/RaiseWasm/RaiseWasmMLIR.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/LogicalResult.h"
+#include <optional>
+
+#define DEBUG_TYPE "wasm-convert"
+
+namespace mlir {
+#define GEN_PASS_DEF_RAISEWASMMLIR
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+namespace {
+
+/// Lower a binary Wasm operation whose arith counterpart depends on the type
+/// of its right-hand operand: an integer operand selects \p TargetIntOp, a
+/// floating-point operand selects \p TargetFPOp, and any other operand type
+/// makes the pattern fail to match.
+template <typename SourceOp, typename TargetIntOp, typename TargetFPOp>
+struct IntFPDispatchMappingConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type operandType = srcOp.getRhs().getType();
+    bool isIntegral = operandType.isInteger();
+    if (!isIntegral && !operandType.isFloat())
+      return failure();
+
+    auto resultTypes = srcOp->getResultTypes();
+    auto newOperands = adaptor.getOperands();
+    if (isIntegral)
+      rewriter.replaceOpWithNewOp<TargetIntOp>(srcOp, resultTypes, newOperands);
+    else
+      rewriter.replaceOpWithNewOp<TargetFPOp>(srcOp, resultTypes, newOperands);
+    return success();
+  }
+};
+
+// Arithmetic ops shared between the integer and floating-point Wasm types.
+using WasmAddOpConversion =
+    IntFPDispatchMappingConversion<AddOp, arith::AddIOp, arith::AddFOp>;
+using WasmSubOpConversion =
+    IntFPDispatchMappingConversion<SubOp, arith::SubIOp, arith::SubFOp>;
+using WasmMulOpConversion =
+    IntFPDispatchMappingConversion<MulOp, arith::MulIOp, arith::MulFOp>;
+
+/// Convert a k-ary source operation \p SourceOp into an operation \p TargetOp.
+/// Both \p SourceOp and \p TargetOp must have the same number of operands.
+/// The new operation reuses the source result types and consumes the operands
+/// supplied by the adaptor.
+template <typename SourceOp, typename TargetOp>
+struct OpMappingConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<TargetOp>(srcOp, srcOp->getResultTypes(),
+                                          adaptor.getOperands());
+    return success();
+  }
+};
+
+/// Convert a binary Wasm shift \p SourceOp into the arith shift \p TargetOp.
+///
+/// Wasm takes the shift amount modulo the bit width of the shifted value
+/// (an `i32` shift by 33 shifts by 1), whereas the arith shift operations
+/// return poison when the amount is greater than or equal to the bit width.
+/// The amount is therefore masked with `bitwidth - 1` before the arith op is
+/// created; this equals the modulo for the power-of-two widths (32 and 64)
+/// used by Wasm integers.
+template <typename SourceOp, typename TargetOp>
+struct ShiftOpMappingConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange operands = adaptor.getOperands();
+    if (operands.size() != 2)
+      return rewriter.notifyMatchFailure(srcOp, "expected two operands");
+    Value shifted = operands[0];
+    Value amount = operands[1];
+    auto amountType = dyn_cast<IntegerType>(amount.getType());
+    if (!amountType)
+      return rewriter.notifyMatchFailure(
+          srcOp, "expected a scalar integer shift amount");
+
+    Location loc = srcOp->getLoc();
+    Value mask = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(amountType, amountType.getWidth() - 1));
+    Value wrappedAmount = rewriter.create<arith::AndIOp>(loc, amount, mask);
+    rewriter.replaceOpWithNewOp<TargetOp>(srcOp, srcOp->getResultTypes(),
+                                          ValueRange{shifted, wrappedAmount});
+    return success();
+  }
+};
+
+using WasmAndOpConversion = OpMappingConversion<AndOp, arith::AndIOp>;
+using WasmCeilOpConversion = OpMappingConversion<CeilOp, math::CeilOp>;
+/// TODO: SIToFP and UIToFP don't allow specification of the floating point
+/// rounding mode
+using WasmConvertSOpConversion =
+    OpMappingConversion<ConvertSOp, arith::SIToFPOp>;
+using WasmConvertUOpConversion =
+    OpMappingConversion<ConvertUOp, arith::UIToFPOp>;
+using WasmDemoteOpConversion = OpMappingConversion<DemoteOp, arith::TruncFOp>;
+/// TODO: Wasm integer division and remainder trap on a zero divisor (and
+/// signed division traps on overflow), whereas the arith operations leave
+/// those cases undefined; the trap semantics are not modelled here.
+using WasmDivFPOpConversion = OpMappingConversion<DivOp, arith::DivFOp>;
+using WasmDivSIOpConversion = OpMappingConversion<DivSIOp, arith::DivSIOp>;
+using WasmDivUIOpConversion = OpMappingConversion<DivUIOp, arith::DivUIOp>;
+using WasmExtendSOpConversion =
+    OpMappingConversion<ExtendSI32Op, arith::ExtSIOp>;
+using WasmExtendUOpConversion =
+    OpMappingConversion<ExtendUI32Op, arith::ExtUIOp>;
+using WasmFloorOpConversion = OpMappingConversion<FloorOp, math::FloorOp>;
+using WasmMaxOpConversion = OpMappingConversion<MaxOp, arith::MaximumFOp>;
+using WasmMinOpConversion = OpMappingConversion<MinOp, arith::MinimumFOp>;
+using WasmOrOpConversion = OpMappingConversion<OrOp, arith::OrIOp>;
+using WasmPromoteOpConversion = OpMappingConversion<PromoteOp, arith::ExtFOp>;
+using WasmRemSIOpConversion = OpMappingConversion<RemSIOp, arith::RemSIOp>;
+using WasmRemUIOpConversion = OpMappingConversion<RemUIOp, arith::RemUIOp>;
+using WasmReinterpretOpConversion =
+    OpMappingConversion<ReinterpretOp, arith::BitcastOp>;
+// Shifts need the amount wrapped to the bit width; see
+// ShiftOpMappingConversion.
+using WasmShLOpConversion = ShiftOpMappingConversion<ShLOp, arith::ShLIOp>;
+using WasmShRSOpConversion = ShiftOpMappingConversion<ShRSOp, arith::ShRSIOp>;
+using WasmShRUOpConversion = ShiftOpMappingConversion<ShRUOp, arith::ShRUIOp>;
+using WasmXOrOpConversion = OpMappingConversion<XOrOp, arith::XOrIOp>;
+using WasmNegOpConversion = OpMappingConversion<NegOp, arith::NegFOp>;
+using WasmCopySignOpConversion =
+    OpMappingConversion<CopySignOp, math::CopySignOp>;
+using WasmClzOpConversion =
+    OpMappingConversion<ClzOp, math::CountLeadingZerosOp>;
+using WasmCtzOpConversion =
+    OpMappingConversion<CtzOp, math::CountTrailingZerosOp>;
+using WasmPopCntOpConversion = OpMappingConversion<PopCntOp, math::CtPopOp>;
+using WasmAbsOpConversion = OpMappingConversion<AbsOp, math::AbsFOp>;
+using WasmTruncOpConversion = OpMappingConversion<TruncOp, math::TruncOp>;
+using WasmSqrtOpConversion = OpMappingConversion<SqrtOp, math::SqrtOp>;
+using WasmWrapOpConversion = OpMappingConversion<WrapOp, arith::TruncIOp>;
+
+/// Lower a Wasm `FuncCallOp` to `func.call` with the same callee and result
+/// types. The call arguments come from the adaptor so that the new call
+/// consumes the values produced by the conversion (the remapped operands)
+/// rather than the original, not-yet-converted operands.
+struct WasmCallOpConversion : OpConversionPattern<FuncCallOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncCallOp funcCallOp, FuncCallOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<func::CallOp>(
+        funcCallOp, funcCallOp.getCallee(), funcCallOp.getResults().getTypes(),
+        adaptor.getOperands());
+    return success();
+  }
+};
+
+/// Replace a Wasm constant with an `arith.constant` built from the same value
+/// attribute.
+struct WasmConstOpConversion : OpConversionPattern<ConstOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstOp constOp, ConstOp::Adaptor /*adaptor*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto literal = constOp.getValue();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, literal);
+    return success();
+  }
+};
+
+/// Turn a Wasm function import into a `func.func` with the imported symbol
+/// name and function type, and give it private symbol visibility.
+struct WasmFuncImportOpConversion : OpConversionPattern<FuncImportOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncImportOp funcImportOp, FuncImportOp::Adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto importedName = funcImportOp.getSymName();
+    auto importedType = funcImportOp.getType();
+    auto declaration = rewriter.replaceOpWithNewOp<func::FuncOp>(
+        funcImportOp, importedName, importedType);
+    declaration.setVisibility(SymbolTable::Visibility::Private);
+    return success();
+  }
+};
+
+/// Lower a Wasm function to a `func.func` with the same symbol name and
+/// function type.
+///
+/// The body is cloned into the new function. The entry block's arguments must
+/// be of `LocalRefType`; a signature conversion rewrites each one to the
+/// reference's element type.
+struct WasmFuncOpConversion : OpConversionPattern<FuncOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncOp funcOp, FuncOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Region &srcBody = funcOp.getBody();
+    if (srcBody.empty())
+      return rewriter.notifyMatchFailure(funcOp, "function has no body");
+
+    // Build the entry-block signature before touching the IR: a conversion
+    // pattern that reports failure must leave the IR unmodified, so every
+    // reason to bail out is checked before the new function is created.
+    Block &srcEntryBlock = srcBody.front();
+    unsigned numArgs = srcEntryBlock.getNumArguments();
+    TypeConverter::SignatureConversion sC{numArgs};
+    for (unsigned i = 0; i < numArgs; ++i) {
+      auto argType =
+          dyn_cast<LocalRefType>(srcEntryBlock.getArgument(i).getType());
+      if (!argType)
+        return rewriter.notifyMatchFailure(
+            funcOp, "expected entry block arguments of LocalRefType");
+      sC.addInputs(i, argType.getElementType());
+    }
+
+    auto newFunc = rewriter.create<func::FuncOp>(
+        funcOp->getLoc(), funcOp.getSymName(), funcOp.getFunctionType());
+    rewriter.cloneRegionBefore(srcBody, newFunc.getBody(),
+                               newFunc.getBody().end());
+    rewriter.applySignatureConversion(&newFunc.getBody().front(), sC,
+                                      getTypeConverter());
+    rewriter.replaceOp(funcOp, newFunc);
+    return success();
+  }
+};
+
+/// Lower a Wasm global import to a `memref.global` of type `memref<1xT>`.
+/// The global has `nested` visibility, no initial value and no alignment.
+/// It is marked constant exactly when the imported global is not mutable.
+struct WasmGlobalImportOpConverter : OpConversionPattern<GlobalImportOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GlobalImportOp gIOp, GlobalImportOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType storageType = MemRefType::get({1}, gIOp.getType());
+    UnitAttr constantFlag =
+        gIOp.getIsMutable() ? UnitAttr{} : rewriter.getUnitAttr();
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        gIOp, gIOp.getSymNameAttr(), rewriter.getStringAttr("nested"),
+        TypeAttr::get(storageType), /*initial_value=*/Attribute{},
+        /*constant=*/constantFlag, /*alignment=*/IntegerAttr{});
+    return success();
+  }
+};
+
+template <typename CRTP, typename OriginOpType>
+struct GlobalOpConverter : OpConversionPattern<GlobalOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GlobalOp globalOp, GlobalOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ReturnOp rop;
+    globalOp->walk([&rop](ReturnOp op) { rop = op; });
----------------
flemairen6 wrote:

Added a `getInitTerminator` method to the `GlobalOp` definition and removed this walk.

https://github.com/llvm/llvm-project/pull/164562


More information about the Mlir-commits mailing list