[Mlir-commits] [mlir] [MLIR][Wasm] Introduce the WasmSSA MLIR dialect (PR #149233)
Mehdi Amini
llvmlistbot at llvm.org
Wed Jul 23 14:47:22 PDT 2025
================
@@ -0,0 +1,510 @@
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp.inc"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+
+namespace {
+/// Infers the single result type of an op from its first operand, which must
+/// be a `LocalRefType`: the inferred type is that reference's element type.
+/// Fails (without emitting a diagnostic) when there are no operands or when
+/// the first operand is not a `LocalRefType`.
+inline LogicalResult
+inferTeeGetResType(ValueRange operands,
+                   ::llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands.empty())
+    return failure();
+  auto opType = llvm::dyn_cast<LocalRefType>(operands.front().getType());
+  if (!opType)
+    return failure();
+  inferredReturnTypes.push_back(opType.getElementType());
+  return success();
+}
+
+/// Parses the import syntax
+///
+///   "<importName>" from "<moduleName>" as @<symbol>
+///
+/// and records the `importName` and `moduleName` string attributes plus the
+/// symbol name (under `SymbolTable::getSymbolAttrName()`) on `result`.
+/// Presumably shared by the import ops' parsers -- confirm at call sites.
+///
+/// Each sub-parse is checked so that parsing stops at the first error, and
+/// an unexpected connector word produces a diagnostic rather than a silent
+/// failure.
+ParseResult parseImportOp(OpAsmParser &parser, OperationState &result) {
+  auto *ctx = parser.getContext();
+
+  std::string importName;
+  if (parser.parseString(&importName))
+    return failure();
+  result.addAttribute("importName", StringAttr::get(ctx, importName));
+
+  // `from` is accepted either as a bare keyword or as a quoted string.
+  auto fromLoc = parser.getCurrentLocation();
+  std::string fromStr;
+  if (parser.parseKeywordOrString(&fromStr))
+    return failure();
+  if (fromStr != "from")
+    return parser.emitError(fromLoc, "expected 'from'");
+
+  std::string moduleName;
+  if (parser.parseString(&moduleName))
+    return failure();
+  result.addAttribute("moduleName", StringAttr::get(ctx, moduleName));
+
+  // `as` is likewise accepted as a bare keyword or a quoted string.
+  auto asLoc = parser.getCurrentLocation();
+  std::string asStr;
+  if (parser.parseKeywordOrString(&asStr))
+    return failure();
+  if (asStr != "as")
+    return parser.emitError(asLoc, "expected 'as'");
+
+  StringAttr symbolName;
+  return parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+                                result.attributes);
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BlockOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the label target of this block op, which is the block reported by
+/// `getTarget()`.
+Block *BlockOp::getLabelTarget() {
+  Block *target = getTarget();
+  return target;
+}
+
+//===----------------------------------------------------------------------===//
+// BlockReturnOp
+//===----------------------------------------------------------------------===//
+
+/// Exit level of a block return; always 0 (presumably the innermost
+/// enclosing label -- confirm against the label-branching interface docs).
+std::size_t BlockReturnOp::getExitLevel() {
+  constexpr std::size_t kExitLevel = 0;
+  return kExitLevel;
+}
+
+/// Resolves the block this return branches to: the first successor of the
+/// operation returned by the label-branching interface's `getTargetOp()`.
+Block *BlockReturnOp::getTarget() {
+  auto branching = cast<WasmSSALabelBranchingInterface>(getOperation());
+  auto *labelOp = branching.getTargetOp().getOperation();
+  return labelOp->getSuccessor(0);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtendLowBitsSOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom form
+///
+///   <N> low bits from <operand> : <integer type>
+///
+/// storing `N` as the `bitsToTake` i64 attribute; the single result has the
+/// same type as the operand.
+///
+/// Every sub-parse is checked so parsing stops at the first failure; in
+/// particular `inType` must not be inspected when `parseType` failed (it
+/// would be a null Type).
+ParseResult ExtendLowBitsSOp::parse(::mlir::OpAsmParser &parser,
+                                    ::mlir::OperationState &result) {
+  OpAsmParser::UnresolvedOperand operand;
+  uint64_t nBits = 0;
+  if (parser.parseInteger(nBits) || parser.parseKeyword("low") ||
+      parser.parseKeyword("bits") || parser.parseKeyword("from") ||
+      parser.parseOperand(operand) || parser.parseColon())
+    return failure();
+
+  auto typeLoc = parser.getCurrentLocation();
+  Type inType;
+  if (parser.parseType(inType))
+    return failure();
+  if (!inType.isInteger())
+    return parser.emitError(typeLoc, "expected integer type, got ") << inType;
+
+  if (parser.resolveOperand(operand, inType, result.operands))
+    return failure();
+  // Same attribute-name lookup style as FuncOp::parse (via `result.name`).
+  result.addAttribute(getBitsToTakeAttrName(result.name),
+                      parser.getBuilder().getI64IntegerAttr(nBits));
+  result.addTypes(inType);
+  return success();
+}
+
+/// Prints the custom form `<N> low bits from <operand>: <type>`.
+void ExtendLowBitsSOp::print(OpAsmPrinter &p) {
+  auto input = getInput();
+  auto bits = getBitsToTake().getUInt();
+  p << " " << bits << " low bits from ";
+  p.printOperand(input);
+  p << ": " << input.getType();
+}
+
+/// Verifies that `bitsToTake` is 8, 16 or 32 and is strictly smaller than
+/// the bit width of the input type.
+LogicalResult ExtendLowBitsSOp::verify() {
+  uint64_t bitsToTake = getBitsToTake().getValue().getLimitedValue();
+  bool isSupportedWidth =
+      bitsToTake == 8 || bitsToTake == 16 || bitsToTake == 32;
+  if (!isSupportedWidth)
+    return emitError("Extend op can only take 8, 16 or 32 bits. Got ")
+           << bitsToTake;
+
+  auto inputType = getInput().getType();
+  if (bitsToTake < inputType.getIntOrFloatBitWidth())
+    return success();
+  return emitError("Trying to extend the ")
+         << bitsToTake << " low bits from a " << inputType << " value";
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+/// Creates the entry block of the body, with one `LocalRefType` argument per
+/// function input (wrapping that input type), located at the op's location.
+/// If the body already has a block, emits an error and returns the existing
+/// front block instead of creating a new one.
+Block *FuncOp::addEntryBlock() {
+  Region &body = getBody();
+  if (!body.empty()) {
+    emitError("Adding entry block to a FuncOp which already has one.");
+    return &body.front();
+  }
+
+  Block &entry = body.emplaceBlock();
+  Location loc = getLoc();
+  for (Type inputType : getFunctionType().getInputs())
+    entry.addArgument(LocalRefType::get(inputType), loc);
+  return &entry;
+}
+
+/// Builds a `FuncOp` named `symbol` with signature `funcType`, `nested`
+/// symbol visibility and one (empty) body region.
+///
+/// Attribute names come from the canonical accessors rather than string
+/// literals so they cannot drift from the op definition; FuncOp::parse uses
+/// the same `getFunctionTypeAttrName` accessor.
+/// TODO(review): consider delegating to an ODS-generated build() overload
+/// instead of populating `odsState` directly.
+void FuncOp::build(::mlir::OpBuilder &odsBuilder,
+                   ::mlir::OperationState &odsState, llvm::StringRef symbol,
+                   FunctionType funcType) {
+  odsState.addAttribute(SymbolTable::getSymbolAttrName(),
+                        odsBuilder.getStringAttr(symbol));
+  odsState.addAttribute(SymbolTable::getVisibilityAttrName(),
+                        odsBuilder.getStringAttr("nested"));
+  odsState.addAttribute(getFunctionTypeAttrName(odsState.name),
+                        TypeAttr::get(funcType));
+  odsState.addRegion();
+}
+
+/// Parses a function op via `function_interface_impl::parseFunctionOp`.
+///
+/// In the textual form, arguments are spelled as `!wasm<local T>`
+/// references, while the stored function type records the referenced `T`.
+/// The type builder unwraps each argument type. On a non-reference argument
+/// it writes the problem into `errorMessage` and returns a null Type, which
+/// parseFunctionOp reports as a parse failure. It must not return a
+/// FunctionType with the offending input silently dropped.
+ParseResult FuncOp::parse(::mlir::OpAsmParser &parser,
+                          ::mlir::OperationState &result) {
+  auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
+                          ArrayRef<Type> results,
+                          function_interface_impl::VariadicFlag,
+                          std::string &errorMessage) -> Type {
+    llvm::SmallVector<Type> argTypesWithoutLocal;
+    argTypesWithoutLocal.reserve(argTypes.size());
+    for (Type argType : argTypes) {
+      auto refType = dyn_cast<LocalRefType>(argType);
+      if (!refType) {
+        llvm::raw_string_ostream os(errorMessage);
+        os << "Invalid type for wasm.func argument. Expecting "
+              "!wasm<local T>, got "
+           << argType << ".";
+        return Type();
+      }
+      argTypesWithoutLocal.push_back(refType.getElementType());
+    }
+    return builder.getFunctionType(argTypesWithoutLocal, results);
+  };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+}
+
+/// Verifies that the entry block (if any) has one `LocalRefType` argument per
+/// function input, each referencing exactly that input's type.
+LogicalResult FuncOp::verifyBody() {
+  // An empty body has nothing to verify.
+  Region &body = getBody();
+  if (body.empty())
+    return success();
+
+  Block &entryBlock = body.front();
+  auto signature = getFunctionType();
+  auto expectedCount = signature.getNumInputs();
+  auto actualCount = entryBlock.getNumArguments();
+  if (actualCount != expectedCount)
+    return emitError("Entry block should have same number of arguments as "
+                     "function type. Function type has ")
+           << expectedCount << ", entry block has " << actualCount << ".";
+
+  for (auto [index, declaredType, argType] :
+       llvm::enumerate(signature.getInputs(), entryBlock.getArgumentTypes())) {
+    auto localRef = dyn_cast<LocalRefType>(argType);
+    if (!localRef)
+      return emitError("Entry block argument type should be LocalRefType, got ")
+             << argType << " for block argument " << index << ".";
+
+    auto referencedType = localRef.getElementType();
+    if (referencedType != declaredType)
+      return emitError("Func argument type #")
+             << index << "(" << declaredType
+             << ") doesn't match entry block referenced type ("
+             << referencedType << ").";
+  }
+  return success();
+}
+
+/// Prints the function via the shared function-op printer. FuncOp::parse
+/// passes allowVariadic=false, so the signature is never printed as variadic.
+void FuncOp::print(OpAsmPrinter &p) {
+  const bool isVariadic = false;
+  function_interface_impl::printFunctionOp(
+      p, *this, isVariadic, getFunctionTypeAttrName(), getArgAttrsAttrName(),
+      getResAttrsAttrName());
+}
+
+//===----------------------------------------------------------------------===//
+// FuncImportOp
+//===----------------------------------------------------------------------===//
+
+void FuncImportOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, StringRef symbol,
+ StringRef moduleName, StringRef importName,
+ FunctionType type) {
+ odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
+ odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
+ odsState.addAttribute("moduleName", odsBuilder.getStringAttr(moduleName));
+ odsState.addAttribute("importName", odsBuilder.getStringAttr(importName));
+ odsState.addAttribute("type", TypeAttr::get(type));
----------------
joker-eph wrote:
The comment means not to populate the `odsState` directly; instead, you would call another `build(...)` with the appropriate parameters.
https://github.com/llvm/llvm-project/pull/149233
More information about the Mlir-commits
mailing list