[Mlir-commits] [mlir] [MLIR][WASM] Introduce the RaiseWasmMLIRPass to lower WasmSSA MLIR to core dialects (PR #164562)
Ferdinand Lemaire
llvmlistbot at llvm.org
Wed Oct 22 22:00:50 PDT 2025
================
@@ -0,0 +1,466 @@
+//===- RaiseWasmMLIR.cpp - Convert Wasm to less abstract dialects ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of WasmSSA operations to standard dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/RaiseWasm/RaiseWasmMLIR.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/LogicalResult.h"
+#include <optional>
+
+#define DEBUG_TYPE "wasm-convert"
+
+namespace mlir {
+#define GEN_PASS_DEF_RAISEWASMMLIR
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+namespace {
+
+/// Lowers a binary WasmSSA op to one of two target ops, selected by the type
+/// of the right-hand operand: \p TargetIntOp for integer types, \p TargetFPOp
+/// for floating-point types. Any other operand type makes the pattern fail
+/// without touching the IR.
+template <typename SourceOp, typename TargetIntOp, typename TargetFPOp>
+struct IntFPDispatchMappingConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type operandType = srcOp.getRhs().getType();
+    const bool isIntOperand = operandType.isInteger();
+    // Reject anything that is neither an integer nor a float up front.
+    if (!isIntOperand && !operandType.isFloat())
+      return failure();
+
+    if (isIntOperand)
+      rewriter.replaceOpWithNewOp<TargetIntOp>(
+          srcOp, srcOp->getResultTypes(), adaptor.getOperands());
+    else
+      rewriter.replaceOpWithNewOp<TargetFPOp>(
+          srcOp, srcOp->getResultTypes(), adaptor.getOperands());
+    return success();
+  }
+};
+
+// Arithmetic ops whose arith lowering depends on integer vs. float operands.
+using WasmAddOpConversion =
+    IntFPDispatchMappingConversion<AddOp, arith::AddIOp, arith::AddFOp>;
+using WasmMulOpConversion =
+    IntFPDispatchMappingConversion<MulOp, arith::MulIOp, arith::MulFOp>;
+using WasmSubOpConversion =
+    IntFPDispatchMappingConversion<SubOp, arith::SubIOp, arith::SubFOp>;
+
+/// One-to-one rewrite of a k-ary \p SourceOp into \p TargetOp: the source's
+/// result types are reused and the converted operands are forwarded as-is, so
+/// both ops must take the same number of operands.
+template <typename SourceOp, typename TargetOp>
+struct OpMappingConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultTypes = srcOp->getResultTypes();
+    auto newOperands = adaptor.getOperands();
+    rewriter.replaceOpWithNewOp<TargetOp>(srcOp, resultTypes, newOperands);
+    return success();
+  }
+};
+
+// Integer division and remainder.
+using WasmDivSIOpConversion = OpMappingConversion<DivSIOp, arith::DivSIOp>;
+using WasmDivUIOpConversion = OpMappingConversion<DivUIOp, arith::DivUIOp>;
+using WasmRemSIOpConversion = OpMappingConversion<RemSIOp, arith::RemSIOp>;
+using WasmRemUIOpConversion = OpMappingConversion<RemUIOp, arith::RemUIOp>;
+
+// Bitwise, shift and bit-counting ops.
+using WasmAndOpConversion = OpMappingConversion<AndOp, arith::AndIOp>;
+using WasmOrOpConversion = OpMappingConversion<OrOp, arith::OrIOp>;
+using WasmXOrOpConversion = OpMappingConversion<XOrOp, arith::XOrIOp>;
+using WasmShLOpConversion = OpMappingConversion<ShLOp, arith::ShLIOp>;
+using WasmShRSOpConversion = OpMappingConversion<ShRSOp, arith::ShRSIOp>;
+using WasmShRUOpConversion = OpMappingConversion<ShRUOp, arith::ShRUIOp>;
+using WasmClzOpConversion =
+    OpMappingConversion<ClzOp, math::CountLeadingZerosOp>;
+using WasmCtzOpConversion =
+    OpMappingConversion<CtzOp, math::CountTrailingZerosOp>;
+using WasmPopCntOpConversion = OpMappingConversion<PopCntOp, math::CtPopOp>;
+
+// Floating-point arithmetic and math ops.
+using WasmAbsOpConversion = OpMappingConversion<AbsOp, math::AbsFOp>;
+using WasmCeilOpConversion = OpMappingConversion<CeilOp, math::CeilOp>;
+using WasmCopySignOpConversion =
+    OpMappingConversion<CopySignOp, math::CopySignOp>;
+using WasmDivFPOpConversion = OpMappingConversion<DivOp, arith::DivFOp>;
+using WasmFloorOpConversion = OpMappingConversion<FloorOp, math::FloorOp>;
+using WasmMaxOpConversion = OpMappingConversion<MaxOp, arith::MaximumFOp>;
+using WasmMinOpConversion = OpMappingConversion<MinOp, arith::MinimumFOp>;
+using WasmNegOpConversion = OpMappingConversion<NegOp, arith::NegFOp>;
+using WasmSqrtOpConversion = OpMappingConversion<SqrtOp, math::SqrtOp>;
+using WasmTruncOpConversion = OpMappingConversion<TruncOp, math::TruncOp>;
+
+// Conversions between numeric types.
+/// TODO: SIToFP and UIToFP don't allow specification of the floating point
+/// rounding mode
+using WasmConvertSOpConversion =
+    OpMappingConversion<ConvertSOp, arith::SIToFPOp>;
+using WasmConvertUOpConversion =
+    OpMappingConversion<ConvertUOp, arith::UIToFPOp>;
+using WasmDemoteOpConversion = OpMappingConversion<DemoteOp, arith::TruncFOp>;
+using WasmPromoteOpConversion = OpMappingConversion<PromoteOp, arith::ExtFOp>;
+using WasmExtendSOpConversion =
+    OpMappingConversion<ExtendSI32Op, arith::ExtSIOp>;
+using WasmExtendUOpConversion =
+    OpMappingConversion<ExtendUI32Op, arith::ExtUIOp>;
+using WasmWrapOpConversion = OpMappingConversion<WrapOp, arith::TruncIOp>;
+using WasmReinterpretOpConversion =
+    OpMappingConversion<ReinterpretOp, arith::BitcastOp>;
+
+/// Lowers a WasmSSA direct call to `func.call` on the same callee symbol,
+/// keeping the original result types.
+struct WasmCallOpConversion : OpConversionPattern<FuncCallOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncCallOp funcCallOp, FuncCallOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Forward the adaptor's operands: they are the values remapped by the
+    // conversion driver, whereas the op's own operands may still refer to
+    // producers that are being replaced.
+    rewriter.replaceOpWithNewOp<func::CallOp>(
+        funcCallOp, funcCallOp.getCallee(), funcCallOp.getResults().getTypes(),
+        adaptor.getOperands());
+    return success();
+  }
+};
+
+/// Materializes a WasmSSA constant as an `arith.constant` that carries the
+/// same value attribute.
+struct WasmConstOpConversion : OpConversionPattern<ConstOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstOp constOp, ConstOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto constantValue = constOp.getValue();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, constantValue);
+    return success();
+  }
+};
+
+/// Replaces an imported Wasm function with a private `func.func` that has the
+/// same symbol name and function type.
+struct WasmFuncImportOpConversion : OpConversionPattern<FuncImportOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncImportOp funcImportOp, FuncImportOp::Adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto symName = funcImportOp.getSymName();
+    auto importType = funcImportOp.getType();
+    func::FuncOp importedFunc = rewriter.replaceOpWithNewOp<func::FuncOp>(
+        funcImportOp, symName, importType);
+    importedFunc.setVisibility(SymbolTable::Visibility::Private);
+    return success();
+  }
+};
+
+/// Lowers a WasmSSA function to `func.func` with the same symbol name and
+/// function type. The body is moved into the new function, and each entry
+/// block argument of type `LocalRefType` is retyped to its element type.
+struct WasmFuncOpConversion : OpConversionPattern<FuncOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncOp funcOp, FuncOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Region &body = funcOp.getBody();
+    if (body.empty())
+      return rewriter.notifyMatchFailure(funcOp, "expected a function body");
+
+    // Build the signature conversion and validate every argument BEFORE
+    // creating any IR: a pattern must not modify the IR and then fail.
+    Block *entryBlock = &body.front();
+    unsigned numArgs = entryBlock->getNumArguments();
+    TypeConverter::SignatureConversion sC{numArgs};
+    for (unsigned i = 0; i < numArgs; ++i) {
+      auto argType = dyn_cast<LocalRefType>(entryBlock->getArgument(i).getType());
+      if (!argType)
+        return rewriter.notifyMatchFailure(
+            funcOp, "expected entry block arguments of local reference type");
+      sC.addInputs(i, argType.getElementType());
+    }
+
+    auto newFunc = rewriter.create<func::FuncOp>(
+        funcOp->getLoc(), funcOp.getSymName(), funcOp.getFunctionType());
+    // The old op is replaced below, so move its body instead of cloning it.
+    // After the move, entryBlock is the first block of newFunc.
+    rewriter.inlineRegionBefore(body, newFunc.getBody(),
+                                newFunc.getBody().end());
+    rewriter.applySignatureConversion(entryBlock, sC, getTypeConverter());
+    rewriter.replaceOp(funcOp, newFunc);
+    return success();
+  }
+};
+
+/// Lowers an imported Wasm global to a `memref.global` of shape `memref<1xT>`
+/// with "nested" visibility and no initial value. The global is marked
+/// constant exactly when the import is not mutable.
+struct WasmGlobalImportOpConverter : OpConversionPattern<GlobalImportOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GlobalImportOp gIOp, GlobalImportOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The imported scalar is stored in a single-element memref.
+    auto storageType = MemRefType::get({1}, gIOp.getType());
+    // Passing the unit attribute directly is equivalent to setConstant().
+    UnitAttr constantAttr =
+        gIOp.getIsMutable() ? UnitAttr{} : rewriter.getUnitAttr();
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        gIOp, gIOp.getSymNameAttr(), rewriter.getStringAttr("nested"),
+        TypeAttr::get(storageType), /*initial_value=*/Attribute{},
+        /*constant=*/constantAttr,
+        /*alignment=*/IntegerAttr{});
+    return success();
+  }
+};
+
+template <typename CRTP, typename OriginOpType>
+struct GlobalOpConverter : OpConversionPattern<GlobalOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(GlobalOp globalOp, GlobalOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ReturnOp rop;
+ globalOp->walk([&rop](ReturnOp op) { rop = op; });
----------------
flemairen6 wrote:
Added a `getInitTerminator` method to the `GlobalOp` definition and removed this walk.
https://github.com/llvm/llvm-project/pull/164562
More information about the Mlir-commits
mailing list