[Mlir-commits] [mlir] [MLIR][Wasm] Introduce the WasmSSA MLIR dialect (PR #149233)
Ferdinand Lemaire
llvmlistbot at llvm.org
Sun Jul 27 18:59:44 PDT 2025
https://github.com/flemairen6 updated https://github.com/llvm/llvm-project/pull/149233
>From d8b21ad8f276fe085197abf67ff93a6365391613 Mon Sep 17 00:00:00 2001
From: Luc Forget <dev at alias.lforget.fr>
Date: Mon, 30 Jun 2025 15:10:12 +0200
Subject: [PATCH 01/11] [mlir][wasm] Introduce the MLIR wasm dialect
This dialect is an SSA representation of a WebAssembly program.
---------
Co-authored-by: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Co-authored-by: Jessica Paquette <jessica.paquette at woven-planet.global>
---
mlir/include/mlir/Dialect/CMakeLists.txt | 1 +
.../Dialect/WebAssemblySSA/CMakeLists.txt | 1 +
.../Dialect/WebAssemblySSA/IR/CMakeLists.txt | 13 +
.../WebAssemblySSA/IR/WebAssemblySSA.h | 55 ++
.../WebAssemblySSA/IR/WebAssemblySSABase.td | 25 +
.../IR/WebAssemblySSAInterfaces.h | 28 +
.../IR/WebAssemblySSAInterfaces.td | 186 +++++
.../WebAssemblySSA/IR/WebAssemblySSAOps.td | 674 ++++++++++++++++++
.../WebAssemblySSA/IR/WebAssemblySSATypes.td | 86 +++
mlir/include/mlir/InitAllDialects.h | 2 +
mlir/lib/Dialect/CMakeLists.txt | 1 +
.../lib/Dialect/WebAssemblySSA/CMakeLists.txt | 1 +
.../Dialect/WebAssemblySSA/IR/CMakeLists.txt | 24 +
.../IR/WebAssemblySSADialect.cpp | 38 +
.../IR/WebAssemblySSAInterfaces.cpp | 61 ++
.../WebAssemblySSA/IR/WebAssemblySSAOps.cpp | 510 +++++++++++++
.../WebAssemblySSA/IR/WebAssemblySSATypes.cpp | 36 +
.../WebAssemblySSA/custom_parser/global.mlir | 44 ++
.../custom_parser/global_illegal.mlir | 23 +
.../WebAssemblySSA/custom_parser/import.mlir | 17 +
.../WebAssemblySSA/custom_parser/local.mlir | 45 ++
21 files changed, 1871 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
create mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
create mode 100644 mlir/lib/Dialect/WebAssemblySSA/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/WebAssemblySSA/IR/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSADialect.cpp
create mode 100644 mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
create mode 100644 mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
create mode 100644 mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
create mode 100644 mlir/test/Dialect/WebAssemblySSA/custom_parser/global.mlir
create mode 100644 mlir/test/Dialect/WebAssemblySSA/custom_parser/global_illegal.mlir
create mode 100644 mlir/test/Dialect/WebAssemblySSA/custom_parser/import.mlir
create mode 100644 mlir/test/Dialect/WebAssemblySSA/custom_parser/local.mlir
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 56dc97282fa4a..eb6075ac76c85 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -41,5 +41,6 @@ add_subdirectory(Transform)
add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
+add_subdirectory(WebAssemblySSA)
add_subdirectory(X86Vector)
add_subdirectory(XeGPU)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/CMakeLists.txt b/mlir/include/mlir/Dialect/WebAssemblySSA/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..fa41a0caabf91
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+set(LLVM_TARGET_DEFINITIONS WebAssemblySSATypes.td)
+mlir_tablegen(WebAssemblySSATypeConstraints.h.inc -gen-type-constraint-decls)
+mlir_tablegen(WebAssemblySSATypeConstraints.cpp.inc -gen-type-constraint-defs)
+
+set(LLVM_TARGET_DEFINITIONS WebAssemblySSAInterfaces.td)
+mlir_tablegen(WebAssemblySSAInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(WebAssemblySSAInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRWebAssemblySSAInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS WebAssemblySSAOps.td)
+
+add_mlir_dialect(WebAssemblySSAOps wasmssa)
+add_mlir_doc(WebAssemblySSAOps WebAssemblySSAOps Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h
new file mode 100644
index 0000000000000..816f7ef008d4a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h
@@ -0,0 +1,55 @@
+//===- WebAssemblySSA.h - WebAssemblySSA dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSA_H_
+#define MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSA_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/Dialect.h"
+
+//===----------------------------------------------------------------------===//
+// WebAssemblySSA Dialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// WebAssemblySSA Dialect Types
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsTypes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// WebAssemblySSA Interfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Headers needed by the traits/interfaces of the generated op classes
+//===----------------------------------------------------------------------===//
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+//===----------------------------------------------------------------------===//
+// WebAssemblySSA Type Constraints
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace wasmssa {
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypeConstraints.h.inc"
+} // namespace wasmssa
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.h.inc"
+
+#endif // MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSA_H_
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td
new file mode 100644
index 0000000000000..cdc4d4864344f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td
@@ -0,0 +1,25 @@
+//===- WebAssemblySSABase.td - wasmssa dialect base defs ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef WEBASSEMBLYSSA_BASE
+#define WEBASSEMBLYSSA_BASE
+
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpBase.td"
+
+def WasmSSA_Dialect : Dialect {
+ let name = "wasmssa";
+ let cppNamespace = "::mlir::wasmssa";
+ let description = [{
+ The `wasmssa` dialect is intended to represent WebAssembly
+ modules in SSA form for easier manipulation.
+ }];
+ let useDefaultTypePrinterParser = true;
+}
+
+#endif // WEBASSEMBLYSSA_BASE
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
new file mode 100644
index 0000000000000..03c4021b1421b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
@@ -0,0 +1,28 @@
+//===- WebAssemblySSAInterfaces.h - WebAssemblySSA Interfaces ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares the WebAssemblySSA op interfaces and their out-of-line verifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSAINTERFACES_H_
+#define MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSAINTERFACES_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir::wasmssa {
+namespace detail { // Verifiers called from the interfaces' `verify` hooks.
+LogicalResult verifyConstantExpressionInterface(Operation *op);
+LogicalResult verifyWasmSSALabelBranchingInterface(Operation *op);
+} // namespace detail
+
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h.inc"
+} // namespace mlir::wasmssa
+
+#endif // MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSAINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
new file mode 100644
index 0000000000000..7857556871c3f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
@@ -0,0 +1,186 @@
+//===- WebAssemblySSAInterfaces.td - WasmSSA interfaces ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for the WebAssemblySSA dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef WEBASSEMBLYSSA_INTERFACES
+#define WEBASSEMBLYSSA_INTERFACES
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinAttributes.td"
+
+def WasmSSALabelLevelInterface : OpInterface<"WasmSSALabelLevelInterface"> {
+ let description = [{
+ Operation that defines one level of nesting for wasm branching.
+ The region of such an operation can be targeted by branch instructions.
+ }];
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns the target block address",
+ /*returnType=*/ "::mlir::Block*",
+ /*methodName=*/ "getLabelTarget",
+ /*args=*/ (ins)
+ >
+ ];
+}
+
+def WasmSSALabelBranchingInterface : OpInterface<"WasmSSALabelBranchingInterface"> {
+ let description = [{
+ Wasm operation that targets a label for a jump.
+ }];
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns the number of nesting levels to break from",
+ /*returnType=*/ "size_t",
+ /*methodName=*/ "getExitLevel",
+ /*args=*/ (ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns the destination of this operation",
+ /*returnType=*/ "WasmSSALabelLevelInterface",
+ /*methodName=*/ "getTargetOp",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{
+ return *WasmSSALabelBranchingInterface::getTargetOpFromBlock($_op.getOperation()->getBlock(), $_op.getExitLevel());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns the block targeted by the label of this operation",
+ /*returnType=*/ "::mlir::Block*",
+ /*methodName=*/ "getTarget",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = mlir::cast<WasmSSALabelBranchingInterface>(this->getOperation());
+ return op.getTargetOp().getLabelTarget();
+ }]
+ >
+ ];
+ let extraClassDeclaration = [{
+ static ::llvm::FailureOr<WasmSSALabelLevelInterface> getTargetOpFromBlock(::mlir::Block *block, uint32_t level);
+ }];
+ let verify = [{return detail::verifyWasmSSALabelBranchingInterface($_op);}];
+}
+
+def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
+ let description = [{
+ Operation that imports a symbol from an external wasm module.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns the module name for the import",
+ /*returnType=*/ "::llvm::StringRef",
+ /*methodName=*/ "getModuleName",
+ /*args=*/ (ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns the import name for the import",
+ /*returnType=*/ "::llvm::StringRef",
+ /*methodName=*/ "getImportName",
+ /*args=*/ (ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns the wasm index based symbol of the op",
+ /*returnType=*/ "::mlir::StringAttr",
+ /*methodName=*/ "getSymbolName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = mlir::cast<ConcreteOp>(this->getOperation());
+ return op.getSymNameAttr();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns the qualified name of the import",
+ /*returnType=*/ "std::string",
+ /*methodName=*/ "getQualifiedImportName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{
+ return ($_op.getModuleName() + llvm::Twine{"::"} + $_op.getImportName()).str();
+ }]
+ >
+ ];
+}
+
+def WasmSSAConstantExpressionInitializerInterface :
+ OpInterface<"WasmSSAConstantExpressionInitializerInterface"> {
+ let description = [{
+ Operation that must be constant initialized. This
+ interface adds a verifier that checks that all ops
+ within the initializer region are "constant expressions"
+ as defined by the WASM standard.
+ }];
+
+ let verify = [{ return detail::verifyConstantExpressionInterface($_op); }];
+}
+
+def WasmSSAConstantExprCheckInterface :
+ OpInterface<"WasmSSAConstantExprCheckInterface"> {
+ let description = [{
+ Base interface for operations that can be used in a Wasm Constant Expression.
+ It shouldn't be used directly, use one of the derived interfaces instead.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ [{
+ Returns success if the current operation is valid in a constant expression context.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "isValidInConstantExpr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{
+ return $_op.verifyConstantExprValidity();
+ }]
+ >
+ ];
+}
+
+def WasmSSAContextuallyConstantExprInterface :
+ OpInterface<"WasmSSAContextuallyConstantExprInterface", [WasmSSAConstantExprCheckInterface]> {
+ let description = [{
+ Base interface for operations that can be used in a Wasm Constant Expression
+ depending on the context.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ [{
+ Returns success if the current operation is valid in a constant expression context.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "verifyConstantExprValidity",
+ /*args=*/ (ins)
+ >
+ ];
+}
+
+def WasmSSAConstantExprInterface :
+ OpInterface<"WasmSSAConstantExprInterface", [WasmSSAConstantExprCheckInterface]> {
+ let description = [{
+ Base interface for operations that can always be used in a Wasm Constant Expression.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ [{
+ Returns success if the current operation is valid in a constant expression context.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "verifyConstantExprValidity",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{ return success(); }]
+ >
+ ];
+}
+
+#endif // WEBASSEMBLYSSA_INTERFACES
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
new file mode 100644
index 0000000000000..9e370920f3173
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
@@ -0,0 +1,674 @@
+//===- WebAssemblySSAOps.td - WebAssemblySSA op definitions -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef WEBASSEMBLYSSA_OPS
+#define WEBASSEMBLYSSA_OPS
+
+
+include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td"
+include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td"
+include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td"
+
+include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+class WasmSSA_Op<string mnemonic, list<Trait> traits = []> :
+ Op<WasmSSA_Dialect, mnemonic, traits>;
+
+class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> :
+ WasmSSA_Op<mnemonic, [Terminator, DeclareOpInterfaceMethods<WasmSSALabelLevelInterface>]> {
+ let summary = summaryStr;
+ let arguments = (ins Variadic<WasmSSA_ValType>: $inputs);
+ let regions = (region AnyRegion: $body);
+ let successors = (successor AnySuccessor: $target);
+ let extraClassDeclaration = [{
+ ::mlir::Block* createBlock() {
+ auto &block = getBody().emplaceBlock();
+ for (auto input : getInputs())
+ block.addArgument(input.getType(), input.getLoc());
+ return &block;
+ }
+ }];
+ let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $body `>` $target";
+}
+
+def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {}
+
+def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
+
+def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
+ DeclareOpInterfaceMethods<WasmSSALabelBranchingInterface>]> {
+ let summary = "Return from the current block";
+ let arguments = (ins Variadic<WasmSSA_ValType>: $inputs);
+ let extraClassDeclaration = [{
+ ::mlir::Block* getTarget();
+ }];
+ let assemblyFormat = "($inputs^ `:` type($inputs))? attr-dict";
+}
+
+def WasmSSA_BranchIfOp : WasmSSA_Op<"branch_if", [
+ Terminator,
+ DeclareOpInterfaceMethods<WasmSSALabelBranchingInterface>]> {
+ let summary = "Jump to target level if condition has non-zero value";
+ let arguments = (ins I32: $condition,
+ UI32Attr: $exitLevel,
+ Variadic<WasmSSA_ValType>: $inputs);
+ let successors = (successor AnySuccessor: $elseSuccessor);
+ let assemblyFormat = "$condition `to` `level` $exitLevel (`with` `args` `(`$inputs^ `:` type($inputs)`)`)? `else` $elseSuccessor attr-dict";
+}
+
+def WasmSSA_ConstOp : WasmSSA_Op<"const", [
+ AllTypesMatch<["value", "result"]>, WasmSSAConstantExprInterface]> {
+ let summary = "Operation that represents a constant value";
+ let arguments = (ins TypedAttrInterface: $value);
+ let results = (outs WasmSSA_NumericType: $result);
+ let assemblyFormat = "$value attr-dict";
+}
+
+def WasmSSA_FuncOp : WasmSSA_Op<"func", [
+ AffineScope, AutomaticAllocationScope,
+ DeclareOpInterfaceMethods<FunctionOpInterface>,
+ IsolatedFromAbove,
+ Symbol]> {
+ let description = [{
+ Represents a Wasm function definition.
+
+ In a Wasm function, locals and function arguments are interchangeable.
+ For instance, both are accessed using the `local.get` instruction.
+
+ On the other hand, a function type is defined as a pair of tuples of Wasm value types.
+ To model this, the `wasmssa.func` operation has:
+
+ - A function type that represents the corresponding wasm type (tuples of value types)
+
+ - Entry block arguments of a local reference type wrapping T, with T the
+ corresponding type in the function type.
+ }];
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ WasmSSA_FuncTypeAttr: $functionType,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs,
+ DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility);
+ let regions = (region AnyRegion: $body);
+ let extraClassDeclaration = [{
+
+ /// Create the entry block for the function with parameters wrapped in local ref.
+ ::mlir::Block* addEntryBlock();
+
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+ ::mlir::LogicalResult verifyBody();
+
+ /// Returns the argument types of this function.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of this function.
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+ }];
+
+ let builders = [
+ OpBuilder<(ins "llvm::StringRef":$symbol, "FunctionType":$funcType)>
+ ];
+ let hasCustomAssemblyFormat = 1;
+}
+
+def WasmSSA_FuncCallOp : WasmSSA_Op<"call"> {
+ let summary = "Calling a wasm function";
+ let arguments = (ins FlatSymbolRefAttr: $callee,
+ Variadic<WasmSSA_ValType>: $operands);
+ let results = (outs Variadic<WasmSSA_ValType>: $results);
+ let assemblyFormat = "$callee (`(`$operands^`)`)? attr-dict `:` functional-type($operands, $results)";
+ let description = [{
+ Emits a call to callee.
+ }];
+}
+
+def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
+ Symbol,
+ CallableOpInterface,
+ WasmSSAImportOpInterface]> {
+ let summary = "Importing a function";
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ StrAttr: $moduleName,
+ StrAttr: $importName,
+ WasmSSA_FuncTypeAttr: $type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let extraClassDeclaration = [{
+ bool isDeclaration() const { return true; }
+
+ Region *getCallableRegion() { return nullptr; }
+
+ llvm::ArrayRef<Type> getArgumentTypes() {
+ return getType().getInputs();
+ }
+
+ llvm::ArrayRef<Type> getResultTypes() {
+ return getType().getResults();
+ }
+ }];
+ let builders = [
+ OpBuilder<(ins "StringRef":$symbol,
+ "StringRef":$moduleName,
+ "StringRef":$importName,
+ "FunctionType": $type)>
+ ];
+ let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
+}
+
+def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
+ AffineScope, AutomaticAllocationScope,
+ IsolatedFromAbove, Symbol, WasmSSAConstantExpressionInitializerInterface]> {
+ let summary = "WebAssembly global value";
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ WasmSSA_ValTypeAttr: $type,
+ UnitAttr: $isMutable,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let description = [{
+ WebAssembly global variable.
+ Body contains the initialization instructions for the variable value.
+ }];
+ let regions = (region AnyRegion: $initializer);
+
+ let builders = [
+ OpBuilder<(ins "StringRef":$symbol,
+ "Type": $type,
+ "bool": $isMutable)>
+ ];
+ let hasCustomAssemblyFormat = 1;
+}
+
+def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
+ Symbol,
+ WasmSSAImportOpInterface]> {
+ let summary = "Importing a global variable";
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ StrAttr: $moduleName,
+ StrAttr: $importName,
+ WasmSSA_ValTypeAttr: $type,
+ UnitAttr: $isMutable,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let extraClassDeclaration = [{
+ bool isDeclaration() const { return true; }
+ }];
+ let builders = [
+ OpBuilder<(ins "StringRef":$symbol,
+ "StringRef":$moduleName,
+ "StringRef":$importName,
+ "Type": $type,
+ "bool": $isMutable)>
+ ];
+ let hasCustomAssemblyFormat = 1;
+}
+
+def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<WasmSSAContextuallyConstantExprInterface>]> {
+ let summary = "Returns the value of the global passed as argument.";
+ let arguments = (ins FlatSymbolRefAttr: $global);
+ let results = (outs WasmSSA_ValType: $global_val);
+ let assemblyFormat = "$global attr-dict `:` type($global_val)";
+}
+
+def WasmSSA_IfOp : WasmSSA_Op<"if", [Terminator,
+ DeclareOpInterfaceMethods<WasmSSALabelLevelInterface>]> {
+ let summary = "Execute the if region if condition value is nonzero, the else region otherwise.";
+ let arguments = (ins I32:$condition, Variadic<WasmSSA_ValType>: $inputs);
+ let regions = (region AnyRegion: $if, AnyRegion: $else);
+ let successors = (successor AnySuccessor: $target);
+ let extraClassDeclaration = [{
+ private:
+ inline ::mlir::Block* createBlock(::mlir::Region& region) {
+ assert(region.empty() && "Creating entry block on non empty region");
+ assert(region.getParentOp() == this->getOperation() &&
+ "Creating block for region that isn't part of the current op");
+ auto &block = region.emplaceBlock();
+ for (auto input : getInputs())
+ block.addArgument(input.getType(), input.getLoc());
+ return &block;
+ }
+
+ public:
+ ::mlir::Block* createIfBlock() {
+ return createBlock(getIf());
+ }
+ ::mlir::Block* createElseBlock() {
+ return createBlock(getElse());
+ }
+ }];
+}
+
+def WasmSSA_LocalOp : WasmSSA_Op<"local", [
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Declaration of local variable";
+ let arguments = (ins WasmSSA_ValTypeAttr: $type);
+ let results = (outs WasmSSA_LocalRef: $result);
+ let assemblyFormat = "`of` `type` $type attr-dict";
+}
+
+def WasmSSA_LocalGetOp : WasmSSA_Op<"local_get", [
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Get the current value of a local variable.";
+ let arguments = (ins WasmSSA_LocalRef: $localVar);
+ let results = (outs WasmSSA_ValType: $result);
+ let assemblyFormat = "$localVar `:` type($localVar) attr-dict";
+ let hasVerifier = 1;
+}
+
+def WasmSSA_LocalSetOp : WasmSSA_Op<"local_set"> {
+ let summary = "Set local to given value";
+ let arguments = (ins WasmSSA_LocalRef: $localVar,
+ WasmSSA_ValType: $value);
+ let hasVerifier = 1;
+ let assemblyFormat = "$localVar `:` type($localVar) `to` $value `:` type($value) attr-dict";
+}
+
+def WasmSSA_LocalTeeOp : WasmSSA_Op<"local_tee", [
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Set local to value and return the operand.";
+ let arguments = (ins WasmSSA_LocalRef: $localVar,
+ WasmSSA_ValType: $value);
+ let results = (outs WasmSSA_ValType: $result);
+ let hasVerifier = 1;
+ let assemblyFormat = "$localVar `:` type($localVar) `to` $value `:` type($value) attr-dict";
+}
+
+def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
+ let summary = "WebAssembly memory definition";
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ WasmSSA_LimitTypeAttr: $limits,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let builders = [
+ OpBuilder<(ins
+ "llvm::StringRef":$symbol,
+ "LimitType":$limit)>
+ ];
+}
+
+def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, WasmSSAImportOpInterface]> {
+ let summary = "Importing a memory";
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ StrAttr: $moduleName,
+ StrAttr: $importName,
+ WasmSSA_LimitTypeAttr: $limits,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let extraClassDeclaration = [{
+ bool isDeclaration() const { return true; }
+ }];
+ let builders = [OpBuilder<(ins
+ "llvm::StringRef":$symbol,
+ "llvm::StringRef":$moduleName,
+ "llvm::StringRef":$importName,
+ "wasmssa::LimitType":$limits)>];
+ let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
+}
+
+def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
+ let summary = "WebAssembly table value";
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ WasmSSA_TableTypeAttr: $type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let builders = [OpBuilder<(ins
+ "llvm::StringRef":$symbol,
+ "wasmssa::TableType":$type)>];
+}
+
+def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, WasmSSAImportOpInterface]> {
+ let summary = "Importing a table";
+ let arguments = (ins SymbolNameAttr: $sym_name,
+ StrAttr: $moduleName,
+ StrAttr: $importName,
+ WasmSSA_TableTypeAttr: $type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let extraClassDeclaration = [{
+ bool isDeclaration() const { return true; }
+ }];
+ let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
+ let builders = [OpBuilder<(ins
+ "llvm::StringRef":$symbol,
+ "llvm::StringRef":$moduleName,
+ "llvm::StringRef":$importName,
+ "wasmssa::TableType":$type)>];
+}
+
+def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> {
+ let summary = "Return from the current function frame";
+ let arguments = (ins Variadic<WasmSSA_ValType>: $operands);
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+ let builders = [
+ OpBuilder<(ins)>
+ ];
+}
+
+// ---- Numeric ops
+
+class WasmSSA_BinaryNumericalOp<string mnemonic, string summaryStr,
+ list<Type> validOpTypes> :
+ WasmSSA_Op<mnemonic, [AllTypesMatch<["lhs", "rhs", "result"]>]> {
+ let summary = summaryStr;
+ let arguments = (ins AnyTypeOf<validOpTypes>:$lhs, AnyTypeOf<validOpTypes>:$rhs);
+ let results = (outs AnyTypeOf<validOpTypes>:$result);
+ let assemblyFormat = "$lhs $rhs `:` type($lhs) attr-dict";
+}
+
+def WasmSSA_AddOp : WasmSSA_BinaryNumericalOp<"add",
+ "Sum two values",
+ [WasmSSA_NumericType]>{}
+
+def WasmSSA_AndOp : WasmSSA_BinaryNumericalOp<"and",
+ "Compute the bitwise AND between two integer values",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_DivOp : WasmSSA_BinaryNumericalOp<"div",
+ "Division between floating point values",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_DivUIOp : WasmSSA_BinaryNumericalOp<"div_ui",
+ "Divide values interpreted as unsigned int",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_DivSIOp : WasmSSA_BinaryNumericalOp<"div_si",
+ "Divide values interpreted as signed int",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_MulOp : WasmSSA_BinaryNumericalOp<"mul",
+ "Multiply two values",
+ [WasmSSA_NumericType]>{}
+
+def WasmSSA_OrOp : WasmSSA_BinaryNumericalOp<"or",
+ "Compute the bitwise OR of two integer values",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_SubOp : WasmSSA_BinaryNumericalOp<"sub",
+ "Subtract two values",
+ [WasmSSA_NumericType]>{}
+
+def WasmSSA_RemUIOp : WasmSSA_BinaryNumericalOp<"rem_ui",
+ "Calculate the remainder of dividing two integer values as an unsigned integer",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_RemSIOp : WasmSSA_BinaryNumericalOp<"rem_si",
+ "Calculate the remainder of dividing two integer values as signed integer",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_XOrOp : WasmSSA_BinaryNumericalOp<"xor",
+ "Compute the bitwise XOR of two integer values",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_MinOp : WasmSSA_BinaryNumericalOp<"min",
+ "Compute the minimum of two floating point values.",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_MaxOp : WasmSSA_BinaryNumericalOp<"max",
+ "Compute the maximum of two floating point values.",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_CopySignOp : WasmSSA_BinaryNumericalOp<"copysign",
+ "Copy sign from one floating point value to the other.",
+ [WasmSSA_FPType]>{}
+
+class WasmSSA_BinaryComparisonOp<string mnemonic, string summaryStr,
+ list<Type> validOpTypes> :
+ WasmSSA_Op<mnemonic, [AllTypesMatch<["lhs", "rhs"]>]> {
+ let summary = summaryStr;
+ let arguments = (ins AnyTypeOf<validOpTypes>:$lhs, AnyTypeOf<validOpTypes>:$rhs);
+ let results = (outs I32:$result);
+ let assemblyFormat = "$lhs $rhs `:` type($lhs) `->` type($result) attr-dict";
+}
+
+def WasmSSA_EqOp : WasmSSA_BinaryComparisonOp<"eq",
+ "Check if two values are equal",
+ [WasmSSA_NumericType]>{}
+
+def WasmSSA_NeOp : WasmSSA_BinaryComparisonOp<"ne",
+ "Check if two values are different",
+ [WasmSSA_NumericType]>{}
+
+def WasmSSA_LtSIOp : WasmSSA_BinaryComparisonOp<"lt_si",
+ "Check if a signed integer value is less than another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_LtUIOp : WasmSSA_BinaryComparisonOp<"lt_ui",
+ "Check if an unsigned integer value is less than another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_LeSIOp : WasmSSA_BinaryComparisonOp<"le_si",
+ "Check if a signed integer value is less than or equal to another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_LeUIOp : WasmSSA_BinaryComparisonOp<"le_ui",
+ "Check if an unsigned integer value is less than or equal to another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_GtSIOp : WasmSSA_BinaryComparisonOp<"gt_si",
+ "Check if a signed integer value is greater than another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_GtUIOp : WasmSSA_BinaryComparisonOp<"gt_ui",
+ "Check if an unsigned integer value is greater than another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_GeSIOp : WasmSSA_BinaryComparisonOp<"ge_si",
+ "Check if a signed integer value is greater than or equal to another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_GeUIOp : WasmSSA_BinaryComparisonOp<"ge_ui",
+ "Check if an unsigned integer value is greater than or equal to another",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_LtOp : WasmSSA_BinaryComparisonOp<"lt",
+ "Check if a float value is less than another",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_LeOp : WasmSSA_BinaryComparisonOp<"le",
+ "Check if a float value is less than or equal to another",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_GtOp : WasmSSA_BinaryComparisonOp<"gt",
+ "Check if a float value is greater than another",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_GeOp : WasmSSA_BinaryComparisonOp<"ge",
+ "Check if a float value is greater than or equal to another",
+ [WasmSSA_FPType]>{}
+
+// Integer shift and rotate operations. Value, shift amount and result all
+// share the same integer type.
+class WasmSSA_ShiftRotateOp<string mnemonic, string summaryStr> :
+  WasmSSA_Op<mnemonic, [AllTypesMatch<["val", "bits", "result"]>]> {
+  let summary = summaryStr;
+  let arguments = (ins WasmSSA_IntegerType:$val, WasmSSA_IntegerType:$bits);
+  let results = (outs WasmSSA_IntegerType:$result);
+  let assemblyFormat = "$val `by` $bits `bits` `:` type($val) attr-dict";
+}
+
+def WasmSSA_ShLOp : WasmSSA_ShiftRotateOp<"shl", "Shift left">{
+  let description = [{
+    Consume an integer and an integer shift amount. The first integer shall
+    be shifted left by N bits, where N is the value of the second integer
+    modulo the bit width of the integer type.
+  }];
+}
+
+def WasmSSA_ShRSOp : WasmSSA_ShiftRotateOp<"shr_s", "Arithmetic right shift">{
+  let description = [{
+    Consume an integer and an integer shift amount. The first integer shall
+    be shifted right by N bits, where N is the value of the second integer
+    modulo the bit width of the integer type.
+
+    Vacated bits on the left shall be filled with the sign bit.
+  }];
+}
+
+def WasmSSA_ShRUOp : WasmSSA_ShiftRotateOp<"shr_u", "Logical right shift">{
+  let description = [{
+    Consume an integer and an integer shift amount. The first integer shall
+    be shifted right by N bits, where N is the value of the second integer
+    modulo the bit width of the integer type.
+
+    Vacated bits on the left shall be filled with zeroes.
+  }];
+}
+
+def WasmSSA_RotlOp : WasmSSA_ShiftRotateOp<"rotl", "Rotate left">{
+  let description = [{
+    Consume an integer and an integer rotation amount. The first integer
+    shall be rotated left by N bits, where N is the value of the second
+    integer modulo the bit width of the integer type.
+  }];
+}
+
+def WasmSSA_RotrOp : WasmSSA_ShiftRotateOp<"rotr", "Rotate right">{
+  let description = [{
+    Consume an integer and an integer rotation amount. The first integer
+    shall be rotated right by N bits, where N is the value of the second
+    integer modulo the bit width of the integer type.
+  }];
+}
+
+// Generic single-operand conversion: the input type is taken from
+// ValidInputTypes and the result type from ValidOutputTypes.
+class WasmSSA_ConversionOp<string mnemonic, string summaryStr,
+                           list<Type> ValidInputTypes,
+                           list<Type> ValidOutputTypes> :
+  WasmSSA_Op<mnemonic> {
+  let summary = summaryStr;
+  let arguments = (ins AnyTypeOf<ValidInputTypes>:$input);
+  let results = (outs AnyTypeOf<ValidOutputTypes>:$result);
+  let assemblyFormat = "$input `:` type($input) `to` type($result) attr-dict";
+}
+
+def WasmSSA_ConvertUOp : WasmSSA_ConversionOp<"convert_u",
+  "Convert an unsigned integer to a floating-point value",
+  [WasmSSA_IntegerType],
+  [WasmSSA_FPType]>{
+  let description = [{
+    Consume an integer, interpreted as an unsigned binary encoded value, and
+    produce a floating point value containing the rounded value of the
+    original operand. Rounding is round to nearest, ties to even.
+  }];
+}
+
+def WasmSSA_ConvertSOp : WasmSSA_ConversionOp<"convert_s",
+  "Convert a signed integer to a floating-point value",
+  [WasmSSA_IntegerType],
+  [WasmSSA_FPType]>{
+  let description = [{
+    Consume an integer, interpreted as a two's complement signed value, and
+    produce a floating point value containing the rounded value of the
+    original operand. Rounding is round to nearest, ties to even.
+  }];
+}
+
+def WasmSSA_DemoteOp : WasmSSA_ConversionOp<"demote",
+  "Convert an f64 value to f32",
+  [F64],
+  [F32]>{}
+
+// Widening i32 -> i64 conversions. They are not ConversionOps because their
+// assembly format omits the input type, which is fixed to i32.
+def WasmSSA_ExtendSI32Op : WasmSSA_Op<"extend_i32_s">{
+ let summary = [{Sign extend i32 to i64.}];
+ let arguments = (ins I32:$input);
+ let results = (outs I64:$result);
+ let assemblyFormat = "$input `to` type($result) attr-dict";
+}
+
+def WasmSSA_ExtendUI32Op : WasmSSA_Op<"extend_i32_u">{
+ let summary = [{Zero extend i32 to i64.}];
+ let arguments = (ins I32:$input);
+ let results = (outs I64:$result);
+ let assemblyFormat = "$input `to` type($result) attr-dict";
+}
+
+def WasmSSA_ExtendLowBitsSOp : WasmSSA_Op<"extend", [AllTypesMatch<["input", "result"]>]> {
+  let summary = "Sign-extend the low bits of an integer value";
+  let description = [{
+    Sign-extend the `bitsToTake` low bits of a value to the full width of its
+    type.
+    For instance, signed extension from 8 low bits of the 32-bits integer value
+    254 (0x000000FE) would produce the value -2 (0xFFFFFFFE).
+
+    This corresponds to the `extendnn` instruction of Wasm, which shouldn't be
+    confused with the `extend_inn` Wasm instruction, for which all input bits
+    are used and widened to wider output type.
+    In this operation, input and output types are the same, and `bitsToTake`
+    must be 8, 16 or 32 and strictly smaller than the input bit width.
+  }];
+  let arguments = (ins WasmSSA_IntegerType:$input, Builtin_IntegerAttr: $bitsToTake);
+  let results = (outs WasmSSA_IntegerType: $result);
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
+// Use the same F32/F64 constraints as DemoteOp for symmetry.
+def WasmSSA_PromoteOp : WasmSSA_ConversionOp<"promote",
+  "Get f64 representation of an f32 value.",
+  [F32],
+  [F64]>{}
+
+def WasmSSA_WrapOp : WasmSSA_ConversionOp<"wrap",
+  "Cast an i64 to i32 by using a wrapping mechanism: y = x mod 2^32",
+  [I64],
+  [I32]>{}
+// Wasm defines one reinterpret instruction per type pair, each with a unique,
+// type-postfixed opcode. They are modeled here as a single op over numeric
+// types with a custom verifier; as the summary states, input and result must
+// have the same representation width.
+def WasmSSA_ReinterpretOp : WasmSSA_ConversionOp<"reinterpret",
+  [{Reinterpret the value represented by a bit vector by
+  bit-casting it to another type of same representation width.}],
+  [WasmSSA_NumericType], [WasmSSA_NumericType]>{
+  let assemblyFormat = "$input `:` type($input) `as` type($result) attr-dict";
+  let hasVerifier = 1;
+}
+
+// Unary ops whose single operand and result share one type drawn from
+// validOpTypes (enforced by AllTypesMatch).
+class WasmSSA_UnaryNumericalOp<string mnemonic,
+ string summaryStr,
+ list<Type> validOpTypes> :
+ WasmSSA_Op<mnemonic, [AllTypesMatch<["src", "result"]>]> {
+ let summary = summaryStr;
+ let arguments = (ins AnyTypeOf<validOpTypes>:$src);
+ let results = (outs AnyTypeOf<validOpTypes>:$result);
+ let assemblyFormat = "$src`:` type($src) attr-dict";
+}
+
+// Floating point unary operations.
+def WasmSSA_AbsOp : WasmSSA_UnaryNumericalOp<"abs",
+ "Floating point absolute value",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_CeilOp : WasmSSA_UnaryNumericalOp<"ceil",
+ "Ceil rounding of floating point value",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_FloorOp : WasmSSA_UnaryNumericalOp<"floor",
+ "Floor rounding of floating point value",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_NegOp : WasmSSA_UnaryNumericalOp<"neg",
+ "Floating point negation",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_SqrtOp : WasmSSA_UnaryNumericalOp<"sqrt",
+ "Floating point square root",
+ [WasmSSA_FPType]>{}
+
+def WasmSSA_TruncOp : WasmSSA_UnaryNumericalOp<"trunc",
+ "Trunc of floating point value",
+ [WasmSSA_FPType]>{}
+
+// Integer unary (bit counting) operations.
+def WasmSSA_CtzOp : WasmSSA_UnaryNumericalOp<"ctz",
+ "Count trailing zeroes of an integer",
+ [WasmSSA_IntegerType]>{}
+
+def WasmSSA_ClzOp : WasmSSA_UnaryNumericalOp<"clz",
+ "Count leading zeroes of an integer",
+ [WasmSSA_IntegerType]>{}
+
+// Unlike the unary numerical ops, eqz always yields an i32 result, whatever
+// the width of its integer input.
+def WasmSSA_EqzOp : WasmSSA_Op<"eqz", []> {
+ let summary = "Check if the given value is equal to zero";
+ let arguments = (ins WasmSSA_IntegerType: $input);
+ let results = (outs I32: $result);
+ let assemblyFormat = "$input`:` type($input) `->` type($result) attr-dict";
+}
+
+
+// Result has the same integer type as the input.
+def WasmSSA_PopCntOp : WasmSSA_UnaryNumericalOp<"popcnt",
+ "Population count of an integer.",
+ [WasmSSA_IntegerType]>{}
+
+
+#endif // WEBASSEMBLYSSA_OPS
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
new file mode 100644
index 0000000000000..1ab2196d70b34
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
@@ -0,0 +1,86 @@
+//===- WebAssemblySSATypes.td - WebAssemblySSA types def ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef WEBASSEMBLYSSA_TYPES
+#define WEBASSEMBLYSSA_TYPES
+
+include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td"
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
+
+// Type constraints mirroring the Wasm value type classes: integer (i32, i64),
+// floating point (f32, f64), numeric (either of those) and 128-bit vector
+// (modeled as i128).
+def WasmSSA_IntegerType : AnyTypeOf<[I32, I64]>;
+def WasmSSA_FPType: AnyTypeOf<[F32, F64]>;
+// Exposed to C++ as isWasmNumericType().
+def WasmSSA_NumericType : AnyTypeOf<[WasmSSA_IntegerType, WasmSSA_FPType]>{
+ let cppFunctionName = "isWasmNumericType";
+}
+def WasmSSA_VecType : AnyTypeOf<[I128]>;
+
+// Base class for types defined by the WasmSSA dialect; typeMnemonic is the
+// keyword used for the type in the dialect's textual form.
+class WasmSSA_Type<string name, string typeMnemonic, list<Trait> traits = []>
+ : TypeDef<WasmSSA_Dialect, name, traits> {
+ let mnemonic = typeMnemonic;
+}
+
+// Parameterless opaque reference types.
+// NOTE(review): FuncRef sets an empty assemblyFormat while ExternRef relies on
+// the default for parameterless types; consider making them consistent.
+def WasmSSA_FuncRefType : WasmSSA_Type<"FuncRef", "funcref"> {
+ let summary = "Opaque type for function reference";
+ let assemblyFormat = "";
+}
+
+def WasmSSA_ExternRefType : WasmSSA_Type<"ExternRef", "externref"> {
+ let summary = "Opaque type for external reference";
+}
+
+// Any reference type; exposed to C++ as isWasmRefType().
+def WasmSSA_RefType : AnyTypeOf<[WasmSSA_FuncRefType, WasmSSA_ExternRefType]> {
+ let cppFunctionName = "isWasmRefType";
+}
+
+// Any Wasm value type (numeric, vector or reference); exposed to C++ as
+// isWasmValueType().
+def WasmSSA_ValType : AnyTypeOf<[WasmSSA_NumericType, WasmSSA_VecType, WasmSSA_RefType]> {
+ let cppFunctionName = "isWasmValueType";
+}
+
+// A tuple of value types.
+def WasmSSA_ResultType : TupleOf<[WasmSSA_ValType]>;
+
+// Function signatures reuse the builtin FunctionType.
+def WasmSSA_FuncType : TypeAlias<FunctionType>;
+
+// Limits with a mandatory minimum and an optional maximum; parsed and printed
+// by hand-written C++ (hasCustomAssemblyFormat).
+def WasmSSA_LimitType : WasmSSA_Type<"Limit", "limit"> {
+ let summary = "Wasm limit type";
+ let parameters = (ins "uint32_t":$min,
+ "std::optional<uint32_t>":$max);
+ let hasCustomAssemblyFormat = 1;
+}
+
+// Reference to a local variable holding a value of elementType, printed as
+// `ref to <elementType>`. The extra builder derives the context from the
+// element type.
+def WasmSSA_LocalRef : WasmSSA_Type<"LocalRef", "local"> {
+ let summary = "Type of a local variable";
+ let parameters = (ins WasmSSA_ValType: $elementType);
+ let assemblyFormat = "`ref` `to` $elementType";
+ let builders = [TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
+ return get(typeParam.getContext(), typeParam);
+ }]>,];
+}
+
+// A table: a reference element type plus its size limits.
+def WasmSSA_TableType : WasmSSA_Type<"Table", "tabletype"> {
+ let summary = "Wasm table type";
+ let parameters = (ins WasmSSA_RefType:$reference,
+ WasmSSA_LimitType:$limit);
+ let assemblyFormat = "$reference $limit";
+}
+
+// TypeAttr wrappers constrained to the corresponding WasmSSA types.
+def WasmSSA_FuncTypeAttr : TypeAttrOf<WasmSSA_FuncType>;
+def WasmSSA_LimitTypeAttr : TypeAttrOf<WasmSSA_LimitType>;
+def WasmSSA_TableTypeAttr : TypeAttrOf<WasmSSA_TableType>;
+def WasmSSA_ValTypeAttr : TypeAttrOf<WasmSSA_ValType>;
+
+// Constant attributes matching the numeric and vector value types.
+def WasmSSA_IntegerAttr : AnyAttrOf<[I32Attr, I64Attr]>;
+def WasmSSA_FPAttr : AnyAttrOf<[F32Attr, F64Attr]>;
+def WasmSSA_NumericAttr : AnyAttrOf<[WasmSSA_IntegerAttr, WasmSSA_FPAttr]>;
+def WasmSSA_VecAttr : TypedSignlessIntegerAttrBase<
+ I128, "::llvm::APInt", "128 bits signless integer attribute">;
+
+#endif // WEBASSEMBLYSSA_TYPES
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c6fcf1a0d510b..5f3676a25d561 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -96,6 +96,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Dialect.h"
@@ -152,6 +153,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
transform::TransformDialect,
ub::UBDialect,
vector::VectorDialect,
+ wasmssa::WasmSSADialect,
x86vector::X86VectorDialect,
xegpu::XeGPUDialect,
xevm::XeVMDialect>();
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 3cc52ebc0a8d9..b0403783d0752 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(Transform)
add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
+add_subdirectory(WebAssemblySSA)
add_subdirectory(X86Vector)
add_subdirectory(XeGPU)
diff --git a/mlir/lib/Dialect/WebAssemblySSA/CMakeLists.txt b/mlir/lib/Dialect/WebAssemblySSA/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/WebAssemblySSA/CMakeLists.txt
@@ -0,0 +1 @@
+# The WasmSSA dialect currently only provides its IR library.
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/CMakeLists.txt b/mlir/lib/Dialect/WebAssemblySSA/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..b106b8b7c2264
--- /dev/null
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/CMakeLists.txt
@@ -0,0 +1,24 @@
+# WasmSSA dialect library: dialect registration, ops, op interfaces and types.
+add_mlir_dialect_library(MLIRWasmSSADialect
+ WebAssemblySSAOps.cpp
+ WebAssemblySSADialect.cpp
+ WebAssemblySSAInterfaces.cpp
+ WebAssemblySSATypes.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/WebAssemblySSA
+
+ # TableGen targets that must be generated before compiling the sources.
+ DEPENDS
+ MLIRWebAssemblySSAOpsIncGen
+ MLIRWebAssemblySSAInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRCastInterfaces
+ MLIRDataLayoutInterfaces
+ MLIRDialect
+ MLIRInferTypeOpInterface
+ MLIRIR
+ MLIRSupport
+
+ # NOTE(review): WebAssemblySSAOps.cpp uses function_interface_impl
+ # (FunctionImplementation.h). If the public dialect headers also expose
+ # function interfaces, this may need to be PUBLIC; confirm.
+ PRIVATE
+ MLIRFunctionInterfaces
+ )
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSADialect.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSADialect.cpp
new file mode 100644
index 0000000000000..a37e77256d970
--- /dev/null
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSADialect.cpp
@@ -0,0 +1,38 @@
+//===- WebAssemblySSADialect.cpp - WebAssemblySSA dialect implementation --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd types definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsTypes.cpp.inc"
+
+// Registers all ops and types of the dialect. The lists are expanded from the
+// GET_OP_LIST / GET_TYPEDEF_LIST sections of the TableGen-generated .inc
+// files.
+void wasmssa::WasmSSADialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp.inc"
+ >();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsTypes.cpp.inc"
+ >();
+}
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
new file mode 100644
index 0000000000000..0b43f2f671062
--- /dev/null
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
@@ -0,0 +1,61 @@
+//===- WebAssemblySSAInterfaces.cpp - WebAssemblySSA Interfaces -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op interfaces for the WebAssemblySSA dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h"
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::wasmssa {
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp.inc"
+
+namespace detail {
+/// Verifies that a branching op's exit level designates an enclosing label op.
+/// This is the interface verifier, so `op` is expected to implement
+/// WasmSSALabelBranchingInterface (the original code already relied on it).
+LogicalResult verifyWasmSSALabelBranchingInterface(Operation *op) {
+  auto branchInterface = cast<WasmSSALabelBranchingInterface>(op);
+  auto res = WasmSSALabelBranchingInterface::getTargetOpFromBlock(
+      op->getBlock(), branchInterface.getExitLevel());
+  // Report the failure instead of failing verification silently.
+  if (failed(res))
+    return op->emitError("Unable to find the label op targeted by exit level ")
+           << branchInterface.getExitLevel();
+  return success();
+}
+
+/// Checks that every op nested in the first region of `op` is either a
+/// ReturnOp or implements WasmSSAConstantExprCheckInterface and reports itself
+/// valid in a constant expression. On the first offending op, emits an error
+/// on `op` and stops walking.
+LogicalResult verifyConstantExpressionInterface(Operation *op) {
+  Region &initRegion = op->getRegion(0);
+  WalkResult walkStatus =
+      initRegion.walk([&](Operation *nestedOp) -> WalkResult {
+        // ReturnOp is always accepted in the initializer.
+        if (isa<ReturnOp>(nestedOp))
+          return WalkResult::advance();
+        auto constCheck =
+            dyn_cast<WasmSSAConstantExprCheckInterface>(nestedOp);
+        bool isConstant =
+            constCheck && constCheck.isValidInConstantExpr().succeeded();
+        if (isConstant)
+          return WalkResult::advance();
+        op->emitError("Expected a constant initializer for this operator, got ")
+            << nestedOp;
+        return WalkResult::interrupt();
+      });
+  return failure(walkStatus.wasInterrupted());
+}
+} // namespace detail
+
+/// Walks up from `block` through enclosing label ops and returns the label op
+/// at `breakLevel` (0 being the innermost enclosing one). Fails when fewer
+/// than `breakLevel + 1` label ops enclose the block, or when the chain is
+/// interrupted by a non-label op or a missing block.
+llvm::FailureOr<WasmSSALabelLevelInterface>
+WasmSSALabelBranchingInterface::getTargetOpFromBlock(::mlir::Block *block,
+                                                     uint32_t breakLevel) {
+  WasmSSALabelLevelInterface res{};
+  for (size_t curLevel{0}; curLevel <= breakLevel; curLevel++) {
+    // A label op that is not nested in a block (e.g. detached) has no parent
+    // to climb to: the requested level cannot be reached.
+    if (!block)
+      return failure();
+    res = dyn_cast_or_null<WasmSSALabelLevelInterface>(block->getParentOp());
+    if (!res)
+      return failure();
+    block = res->getBlock();
+  }
+  return res;
+}
+} // namespace mlir::wasmssa
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
new file mode 100644
index 0000000000000..00bb83caa9ee6
--- /dev/null
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
@@ -0,0 +1,510 @@
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp.inc"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+
+namespace {
+/// Shared return type inference: the result type is the element type of the
+/// LocalRefType carried by the first operand. Fails when there is no operand
+/// or the first operand is not a local reference.
+inline LogicalResult
+inferTeeGetResType(ValueRange operands,
+                   ::llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
+  LocalRefType localRef =
+      operands.empty()
+          ? LocalRefType()
+          : llvm::dyn_cast<LocalRefType>(operands.front().getType());
+  if (!localRef)
+    return failure();
+  inferredReturnTypes.push_back(localRef.getElementType());
+  return success();
+}
+
+/// Parses the common import prefix `"importName" from "moduleName" as @sym`.
+/// Adds `importName`, `moduleName` and the symbol name attribute to `result`.
+/// Every step's failure is propagated, and unexpected separators produce a
+/// diagnostic instead of a silent failure.
+ParseResult parseImportOp(OpAsmParser &parser, OperationState &result) {
+  auto *ctx = parser.getContext();
+
+  std::string importName;
+  if (parser.parseString(&importName))
+    return failure();
+  result.addAttribute("importName", StringAttr::get(ctx, importName));
+
+  // `from` and `as` are accepted either as bare keywords or quoted strings.
+  llvm::SMLoc fromLoc = parser.getCurrentLocation();
+  std::string fromStr;
+  if (parser.parseKeywordOrString(&fromStr))
+    return failure();
+  if (fromStr != "from")
+    return parser.emitError(fromLoc, "expected 'from' after the import name");
+
+  std::string moduleName;
+  if (parser.parseString(&moduleName))
+    return failure();
+  result.addAttribute("moduleName", StringAttr::get(ctx, moduleName));
+
+  llvm::SMLoc asLoc = parser.getCurrentLocation();
+  std::string asStr;
+  if (parser.parseKeywordOrString(&asStr))
+    return failure();
+  if (asStr != "as")
+    return parser.emitError(asLoc, "expected 'as' after the module name");
+
+  StringAttr symbolName;
+  return parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+                                result.attributes);
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BlockOp
+//===----------------------------------------------------------------------===//
+
+// Label target of a block op: forwards to the ODS-generated getTarget()
+// accessor (presumably the op's successor block; confirm in the op's .td).
+Block *BlockOp::getLabelTarget() { return getTarget(); }
+
+//===----------------------------------------------------------------------===//
+// BlockReturnOp
+//===----------------------------------------------------------------------===//
+
+/// A block_return always exits the innermost enclosing label op (level 0).
+std::size_t BlockReturnOp::getExitLevel() {
+  constexpr std::size_t innermostLabel = 0;
+  return innermostLabel;
+}
+
+/// Resolves the label op this return exits from and returns its first
+/// successor block.
+Block *BlockReturnOp::getTarget() {
+  auto branching = cast<WasmSSALabelBranchingInterface>(getOperation());
+  Operation *labelOp = branching.getTargetOp().getOperation();
+  return labelOp->getSuccessor(0);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtendLowBitsSOp
+//===----------------------------------------------------------------------===//
+
+/// Parses `<N> low bits from %operand : <int type> [attr-dict]`.
+/// Every sub-parser result is checked, so a malformed input can no longer read
+/// an uninitialized bit count or dereference a null type.
+ParseResult ExtendLowBitsSOp::parse(::mlir::OpAsmParser &parser,
+                                    ::mlir::OperationState &result) {
+  OpAsmParser::UnresolvedOperand operand;
+  uint64_t nBits = 0;
+  if (parser.parseInteger(nBits) || parser.parseKeyword("low") ||
+      parser.parseKeyword("bits") || parser.parseKeyword("from") ||
+      parser.parseOperand(operand) || parser.parseColon())
+    return failure();
+
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  Type inType;
+  if (parser.parseType(inType))
+    return failure();
+  if (!inType.isInteger())
+    return parser.emitError(typeLoc, "expected an integer type, got ")
+           << inType;
+
+  llvm::SmallVector<Value, 1> opVal;
+  if (parser.resolveOperand(operand, inType, opVal) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  result.addOperands(opVal);
+  result.addAttribute(getBitsToTakeAttrName(result.name),
+                      parser.getBuilder().getI64IntegerAttr(nBits));
+  // Input and result types are identical (AllTypesMatch).
+  result.addTypes(inType);
+  return success();
+}
+
+/// Prints the custom form parsed above; discardable attributes are printed as
+/// a trailing attribute dictionary so they survive a round-trip.
+void ExtendLowBitsSOp::print(OpAsmPrinter &p) {
+  p << " " << getBitsToTake().getValue().getLimitedValue() << " low bits from ";
+  p.printOperand(getInput());
+  p << ": " << getInput().getType();
+  p.printOptionalAttrDict((*this)->getAttrs(), {getBitsToTakeAttrName()});
+}
+
+/// Only 8, 16 or 32 low bits may be taken, and strictly fewer bits than the
+/// input type is wide.
+LogicalResult ExtendLowBitsSOp::verify() {
+  auto bitsToTake = getBitsToTake().getValue().getLimitedValue();
+  if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8)
+    return emitError("Extend op can only take 8, 16 or 32 bits. Got ")
+           << bitsToTake;
+
+  if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth())
+    return emitError("Trying to extend the ")
+           << bitsToTake << " low bits from a " << getInput().getType()
+           << " value";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+/// Creates the entry block with one LocalRefType argument per function input.
+/// If the body already has a block, emits an error and returns that block.
+Block *FuncOp::addEntryBlock() {
+  Region &body = getBody();
+  if (body.empty()) {
+    Block &entry = body.emplaceBlock();
+    Location loc = getLoc();
+    for (Type inputType : getFunctionType().getInputs())
+      entry.addArgument(LocalRefType::get(inputType), loc);
+    return &entry;
+  }
+  emitError("Adding entry block to a FuncOp which already has one.");
+  return &body.front();
+}
+
+/// Builds a function with the given symbol name and signature, with "nested"
+/// visibility and an empty body region.
+void FuncOp::build(::mlir::OpBuilder &odsBuilder,
+                   ::mlir::OperationState &odsState, llvm::StringRef symbol,
+                   FunctionType funcType) {
+  StringAttr nameAttr = odsBuilder.getStringAttr(symbol);
+  StringAttr visibilityAttr = odsBuilder.getStringAttr("nested");
+  odsState.addAttribute("sym_name", nameAttr);
+  odsState.addAttribute("sym_visibility", visibilityAttr);
+  odsState.addAttribute("functionType", TypeAttr::get(funcType));
+  // Body region; no block is created here.
+  odsState.addRegion();
+}
+
+/// Parses a function through the generic function-op parser. In the textual
+/// signature, arguments are spelled with their LocalRefType (the entry block
+/// argument type), while the function type stores the referenced element
+/// types.
+ParseResult FuncOp::parse(::mlir::OpAsmParser &parser,
+                          ::mlir::OperationState &result) {
+  auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
+                          ArrayRef<Type> results,
+                          function_interface_impl::VariadicFlag,
+                          std::string &errorMessage) -> Type {
+    llvm::SmallVector<Type> argTypesWithoutLocal;
+    argTypesWithoutLocal.reserve(argTypes.size());
+    for (Type argType : argTypes) {
+      auto refType = dyn_cast<LocalRefType>(argType);
+      if (!refType) {
+        // Returning a null type makes parseFunctionOp fail and report
+        // errorMessage at the signature, instead of emitting an error while
+        // still producing a (wrong) function type.
+        llvm::raw_string_ostream os(errorMessage);
+        os << "Invalid type for wasm.func argument. Expecting "
+              "!wasm<local T>, got "
+           << argType << ".";
+        return Type();
+      }
+      argTypesWithoutLocal.push_back(refType.getElementType());
+    }
+    return builder.getFunctionType(argTypesWithoutLocal, results);
+  };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+}
+
+/// Checks that the entry block (if any) takes one LocalRefType argument per
+/// function input, each referencing exactly the corresponding input type.
+LogicalResult FuncOp::verifyBody() {
+  Region &body = getBody();
+  if (body.empty())
+    return success();
+
+  auto funcType = getFunctionType();
+  Block &entryBlock = body.front();
+  unsigned numFuncInputs = funcType.getNumInputs();
+  unsigned numBlockArgs = entryBlock.getNumArguments();
+  if (numBlockArgs != numFuncInputs)
+    return emitError("Entry block should have same number of arguments as "
+                     "function type. Function type has ")
+           << numFuncInputs << ", entry block has " << numBlockArgs << ".";
+
+  ArrayRef<Type> inputTypes = funcType.getInputs();
+  for (unsigned idx = 0; idx < numFuncInputs; ++idx) {
+    Type expectedType = inputTypes[idx];
+    Type actualType = entryBlock.getArgument(idx).getType();
+    auto localRef = dyn_cast<LocalRefType>(actualType);
+    if (!localRef)
+      return emitError("Entry block argument type should be LocalRefType, got ")
+             << actualType << " for block argument " << idx << ".";
+    Type referencedType = localRef.getElementType();
+    if (referencedType != expectedType)
+      return emitError("Func argument type #")
+             << idx << "(" << expectedType
+             << ") doesn't match entry block referenced type ("
+             << referencedType << ").";
+  }
+  return success();
+}
+
+// Prints through the generic function-op printer; variadic signatures are not
+// supported (isVariadic is always false).
+void FuncOp::print(OpAsmPrinter &p) {
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
+}
+
+//===----------------------------------------------------------------------===//
+// FuncImportOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a function import of `importName` from `moduleName`, bound to the
+/// symbol `symbol` with "nested" visibility and signature `type`.
+void FuncImportOp::build(::mlir::OpBuilder &odsBuilder,
+                         ::mlir::OperationState &odsState, StringRef symbol,
+                         StringRef moduleName, StringRef importName,
+                         FunctionType type) {
+  auto addStrAttr = [&](StringRef attrName, StringRef value) {
+    odsState.addAttribute(attrName, odsBuilder.getStringAttr(value));
+  };
+  addStrAttr("sym_name", symbol);
+  addStrAttr("sym_visibility", "nested");
+  addStrAttr("moduleName", moduleName);
+  addStrAttr("importName", importName);
+  odsState.addAttribute("type", TypeAttr::get(type));
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a global named `symbol` of the given type with "nested" visibility
+/// and an empty initializer region; `isMutable` adds the unit attribute.
+void GlobalOp::build(::mlir::OpBuilder &odsBuilder,
+                     ::mlir::OperationState &odsState, llvm::StringRef symbol,
+                     Type type, bool isMutable) {
+  auto addStrAttr = [&](StringRef attrName, StringRef value) {
+    odsState.addAttribute(attrName, odsBuilder.getStringAttr(value));
+  };
+  addStrAttr("sym_name", symbol);
+  addStrAttr("sym_visibility", "nested");
+  odsState.addAttribute("type", TypeAttr::get(type));
+  if (isMutable)
+    odsState.addAttribute("isMutable", odsBuilder.getUnitAttr());
+  // Initializer region, left empty.
+  odsState.addRegion();
+}
+
+// Custom formats
+/// Parses `@sym <type> [mutable] [visibility] : [{ initializer }]`.
+/// The first optional word is only consumed as the visibility when it is not
+/// `mutable`, so a non-mutable global keeps its visibility. The initializer
+/// region is optional because the printer omits it when it is empty.
+ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
+  StringAttr symbolName;
+  Type globalType;
+  auto *ctx = parser.getContext();
+  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+                             result.attributes) ||
+      parser.parseType(globalType))
+    return failure();
+  result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
+
+  std::string mutableOrSymVisString;
+  ParseResult res =
+      parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
+  if (res.succeeded() && mutableOrSymVisString == "mutable") {
+    result.addAttribute("isMutable", UnitAttr::get(ctx));
+    res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
+  }
+  if (res.succeeded())
+    result.addAttribute("sym_visibility",
+                        StringAttr::get(ctx, mutableOrSymVisString));
+
+  if (parser.parseColon())
+    return failure();
+  Region *globalInitRegion = result.addRegion();
+  OptionalParseResult regionResult =
+      parser.parseOptionalRegion(*globalInitRegion);
+  if (regionResult.has_value())
+    return *regionResult;
+  return success();
+}
+
+/// Prints `@sym <type> [mutable] [visibility] :` followed by the initializer
+/// region when it is not empty.
+void GlobalOp::print(OpAsmPrinter &printer) {
+  printer << ' ';
+  // printSymbolName quotes names that are not bare identifiers so that the
+  // output can be read back by parseSymbolName.
+  printer.printSymbolName(getSymName());
+  printer << " " << getType();
+  if (getIsMutable())
+    printer << " mutable";
+  if (auto vis = getSymVisibility())
+    printer << " " << *vis;
+  printer << " :";
+  Region &body = getRegion();
+  if (!body.empty()) {
+    printer << ' ';
+    printer.printRegion(body, /*printEntryBlockArgs=*/false,
+                        /*printBlockTerminators=*/true);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalGetOp
+//===----------------------------------------------------------------------===//
+
+// Custom interface overrides
+/// A global.get is valid in a constant expression only if it refers to a
+/// non-mutable imported global. Resolution failures are reported instead of
+/// crashing (null symbol table) or failing silently (unknown symbol).
+LogicalResult GlobalGetOp::verifyConstantExprValidity() {
+  StringRef referencedSymbol = getGlobal();
+  Operation *symTableOp =
+      getOperation()->getParentWithTrait<OpTrait::SymbolTable>();
+  // SymbolTable::lookupSymbolIn must not be called with a null table.
+  if (!symTableOp)
+    return emitError("global.get op is not nested in a symbol table, cannot "
+                     "resolve symbol @")
+           << referencedSymbol;
+  Operation *definitionOp =
+      SymbolTable::lookupSymbolIn(symTableOp, referencedSymbol);
+  if (!definitionOp)
+    return emitError("global.get op references unknown symbol @")
+           << referencedSymbol;
+  auto definitionImport = llvm::dyn_cast<GlobalImportOp>(definitionOp);
+  if (!definitionImport || definitionImport.getIsMutable()) {
+    return emitError("global.get op is considered constant if it's referring "
+                     "to a import.global symbol marked non-mutable.");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalImportOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a global import of `importName` from `moduleName`, bound to
+/// `symbol` with "nested" visibility; `isMutable` adds the unit attribute.
+void GlobalImportOp::build(::mlir::OpBuilder &odsBuilder,
+                           ::mlir::OperationState &odsState, StringRef symbol,
+                           StringRef moduleName, StringRef importName,
+                           Type type, bool isMutable) {
+  auto addStrAttr = [&](StringRef attrName, StringRef value) {
+    odsState.addAttribute(attrName, odsBuilder.getStringAttr(value));
+  };
+  addStrAttr("sym_name", symbol);
+  addStrAttr("sym_visibility", "nested");
+  addStrAttr("moduleName", moduleName);
+  addStrAttr("importName", importName);
+  odsState.addAttribute("type", TypeAttr::get(type));
+  if (isMutable)
+    odsState.addAttribute("isMutable", odsBuilder.getUnitAttr());
+}
+
+/// Parses `"name" from "module" as @sym [mutable] [visibility] : <type>`.
+/// A missing colon is now a hard failure rather than being ignored (which
+/// could return success after an emitted error).
+ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto *ctx = parser.getContext();
+  if (parseImportOp(parser, result))
+    return failure();
+
+  // The first optional word is either `mutable` or the symbol visibility.
+  std::string mutableOrSymVisString;
+  ParseResult res =
+      parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
+  if (res.succeeded() && mutableOrSymVisString == "mutable") {
+    result.addAttribute("isMutable", UnitAttr::get(ctx));
+    res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
+  }
+  if (res.succeeded())
+    result.addAttribute("sym_visibility",
+                        StringAttr::get(ctx, mutableOrSymVisString));
+
+  Type importedType;
+  if (parser.parseColon() || parser.parseType(importedType))
+    return failure();
+  result.addAttribute(getTypeAttrName(result.name),
+                      TypeAttr::get(importedType));
+  return success();
+}
+
+/// Prints `"name" from "module" as @sym [mutable] [visibility] : <type>`.
+/// Strings are escaped and the symbol quoted when needed, so names containing
+/// quotes or special characters round-trip through the parser.
+void GlobalImportOp::print(OpAsmPrinter &printer) {
+  printer << ' ';
+  printer.printString(getImportName());
+  printer << " from ";
+  printer.printString(getModuleName());
+  printer << " as ";
+  printer.printSymbolName(getSymName());
+  if (getIsMutable())
+    printer << " mutable";
+  if (auto vis = getSymVisibility())
+    printer << " " << *vis;
+  printer << " : " << getType();
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+// Label target of an if op: forwards to the ODS-generated getTarget()
+// accessor (presumably the op's successor block; confirm in the op's .td).
+Block *IfOp::getLabelTarget() { return getTarget(); }
+
+//===----------------------------------------------------------------------===//
+// LocalOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the result of `wasmssa.local`: a LocalRefType whose element type is
+/// the value of the op's `type` attribute. Fails if that attribute is absent.
+LogicalResult LocalOp::inferReturnTypes(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
+    RegionRange regions, ::llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
+  LocalOp::GenericAdaptor<ValueRange> adaptor{operands, attributes, properties,
+                                              regions};
+  if (auto localTypeAttr = adaptor.getTypeAttr()) {
+    inferredReturnTypes.push_back(LocalRefType::get(
+        localTypeAttr.getContext(), localTypeAttr.getValue()));
+    return success();
+  }
+  // Without a `type` attribute there is nothing to infer from.
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// LocalGetOp
+//===----------------------------------------------------------------------===//
+
+/// Result type inference for `wasmssa.local_get`. It only depends on the
+/// operands and is shared with LocalTeeOp through inferTeeGetResType.
+LogicalResult LocalGetOp::inferReturnTypes(
+    MLIRContext * /*context*/, ::std::optional<Location> /*location*/,
+    ValueRange operands, DictionaryAttr /*attributes*/,
+    OpaqueProperties /*properties*/, RegionRange /*regions*/,
+    ::llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferTeeGetResType(operands, inferredReturnTypes);
+}
+
+/// Checks that the loaded value has the element type of the referenced local.
+/// Emits a diagnostic on mismatch: a verifier must never fail silently.
+LogicalResult LocalGetOp::verify() {
+  Type localType = getLocalVar().getType().getElementType();
+  Type resultType = getResult().getType();
+  if (localType != resultType)
+    return emitOpError("result type ")
+           << resultType << " does not match the element type " << localType
+           << " of the local variable";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LocalSetOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the stored value has the element type of the referenced local.
+/// Emits a diagnostic on mismatch: a verifier must never fail silently.
+LogicalResult LocalSetOp::verify() {
+  Type localType = getLocalVar().getType().getElementType();
+  Type valueType = getValue().getType();
+  if (localType != valueType)
+    return emitOpError("value type ")
+           << valueType << " does not match the element type " << localType
+           << " of the local variable";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LocalTeeOp
+//===----------------------------------------------------------------------===//
+
+/// Result type inference for `wasmssa.local_tee`. It only depends on the
+/// operands and is shared with LocalGetOp through inferTeeGetResType.
+LogicalResult LocalTeeOp::inferReturnTypes(
+    MLIRContext * /*context*/, ::std::optional<Location> /*location*/,
+    ValueRange operands, DictionaryAttr /*attributes*/,
+    OpaqueProperties /*properties*/, RegionRange /*regions*/,
+    ::llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferTeeGetResType(operands, inferredReturnTypes);
+}
+
+/// Checks that the operand, the result and the local's element type agree.
+/// Each mismatch gets its own diagnostic instead of a silent failure.
+LogicalResult LocalTeeOp::verify() {
+  Type localType = getLocalVar().getType().getElementType();
+  Type valueType = getValue().getType();
+  if (localType != valueType)
+    return emitOpError("value type ")
+           << valueType << " does not match the element type " << localType
+           << " of the local variable";
+  Type resultType = getResult().getType();
+  if (resultType != valueType)
+    return emitOpError("result type ")
+           << resultType << " does not match the value type " << valueType;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+/// Branching to a loop level targets the first block of the loop body.
+Block *LoopOp::getLabelTarget() {
+  Region &body = getBody();
+  return &body.front();
+}
+
+//===----------------------------------------------------------------------===//
+// MemOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a memory definition named `symbol` with the given limits. The
+/// symbol visibility is always set to "nested".
+void MemOp::build(::mlir::OpBuilder &odsBuilder,
+                  ::mlir::OperationState &odsState, llvm::StringRef symbol,
+                  LimitType limit) {
+  StringAttr nameAttr = odsBuilder.getStringAttr(symbol);
+  StringAttr visibilityAttr = odsBuilder.getStringAttr("nested");
+  TypeAttr limitsAttr = TypeAttr::get(limit);
+  odsState.addAttribute("sym_name", nameAttr);
+  odsState.addAttribute("sym_visibility", visibilityAttr);
+  odsState.addAttribute("limits", limitsAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// MemImportOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a memory import `moduleName::importName` bound to `symbol`, with the
+/// given limits. The symbol visibility is always set to "nested".
+void MemImportOp::build(mlir::OpBuilder &odsBuilder,
+                        ::mlir::OperationState &odsState,
+                        llvm::StringRef symbol, llvm::StringRef moduleName,
+                        llvm::StringRef importName, LimitType limits) {
+  StringAttr nameAttr = odsBuilder.getStringAttr(symbol);
+  StringAttr visibilityAttr = odsBuilder.getStringAttr("nested");
+  StringAttr moduleAttr = odsBuilder.getStringAttr(moduleName);
+  StringAttr importAttr = odsBuilder.getStringAttr(importName);
+  odsState.addAttribute("sym_name", nameAttr);
+  odsState.addAttribute("sym_visibility", visibilityAttr);
+  odsState.addAttribute("moduleName", moduleAttr);
+  odsState.addAttribute("importName", importAttr);
+  odsState.addAttribute("limits", TypeAttr::get(limits));
+}
+
+//===----------------------------------------------------------------------===//
+// ReinterpretOp
+//===----------------------------------------------------------------------===//
+
+/// A reinterpret must change the type while keeping the bit width.
+LogicalResult ReinterpretOp::verify() {
+  Type inputType = getInput().getType();
+  Type resultType = getResult().getType();
+  if (inputType == resultType)
+    return emitError("reinterpret input and output type should be distinct.");
+  unsigned inputWidth = inputType.getIntOrFloatBitWidth();
+  unsigned resultWidth = resultType.getIntOrFloatBitWidth();
+  if (inputWidth == resultWidth)
+    return success();
+  return emitError() << "input type (" << inputType << ") and output type ("
+                     << resultType << ") have incompatible bit widths.";
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a `wasmssa.return` without adding any operands or attributes to the
+/// state; this builder is intentionally empty.
+void ReturnOp::build(::mlir::OpBuilder & /*odsBuilder*/,
+                     ::mlir::OperationState & /*odsState*/) {}
+
+//===----------------------------------------------------------------------===//
+// TableOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a table definition named `symbol` with the given table type. The
+/// symbol visibility is always set to "nested".
+void TableOp::build(::mlir::OpBuilder &odsBuilder,
+                    ::mlir::OperationState &odsState, llvm::StringRef symbol,
+                    TableType type) {
+  StringAttr nameAttr = odsBuilder.getStringAttr(symbol);
+  StringAttr visibilityAttr = odsBuilder.getStringAttr("nested");
+  TypeAttr tableTypeAttr = TypeAttr::get(type);
+  odsState.addAttribute("sym_name", nameAttr);
+  odsState.addAttribute("sym_visibility", visibilityAttr);
+  odsState.addAttribute("type", tableTypeAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// TableImportOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a table import `moduleName::importName` bound to `symbol`, with the
+/// given table type. The symbol visibility is always set to "nested".
+void TableImportOp::build(mlir::OpBuilder &odsBuilder,
+                          ::mlir::OperationState &odsState,
+                          llvm::StringRef symbol, llvm::StringRef moduleName,
+                          llvm::StringRef importName, TableType type) {
+  StringAttr nameAttr = odsBuilder.getStringAttr(symbol);
+  StringAttr visibilityAttr = odsBuilder.getStringAttr("nested");
+  StringAttr moduleAttr = odsBuilder.getStringAttr(moduleName);
+  StringAttr importAttr = odsBuilder.getStringAttr(importName);
+  odsState.addAttribute("sym_name", nameAttr);
+  odsState.addAttribute("sym_visibility", visibilityAttr);
+  odsState.addAttribute("moduleName", moduleAttr);
+  odsState.addAttribute("importName", importAttr);
+  odsState.addAttribute("type", TypeAttr::get(type));
+}
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
new file mode 100644
index 0000000000000..27b3af5af0a6f
--- /dev/null
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
@@ -0,0 +1,36 @@
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/LogicalResult.h"
+
+#include <optional>
+
+namespace mlir::wasmssa {
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypeConstraints.cpp.inc"
+} // namespace mlir::wasmssa
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+
+/// Parses a limit of the form `[` min `:` max? `]`, e.g. `[2:]` or `[1:16]`.
+/// Returns a null Type on malformed input. Each parser call reports its own
+/// error, so failures are propagated instead of ignored.
+Type LimitType::parse(::mlir::AsmParser &parser) {
+  uint32_t minLimit = 0;
+  if (parser.parseLSquare() || parser.parseInteger(minLimit) ||
+      parser.parseColon())
+    return {};
+
+  std::optional<uint32_t> maxLimit;
+  uint32_t maxValue = 0;
+  OptionalParseResult maxParseRes = parser.parseOptionalInteger(maxValue);
+  if (maxParseRes.has_value()) {
+    // An integer token was present but could not be parsed into a uint32_t
+    // (e.g. out of range). The parser has already emitted the error.
+    if (failed(*maxParseRes))
+      return {};
+    maxLimit = maxValue;
+  }
+
+  if (parser.parseRSquare())
+    return {};
+  return LimitType::get(parser.getContext(), minLimit, maxLimit);
+}
+
+/// Prints the limit as `[min:max]`. The max field is left empty when no
+/// maximum is set, matching what LimitType::parse accepts.
+void LimitType::print(AsmPrinter &printer) const {
+  printer << '[' << getMin() << ':';
+  if (std::optional<uint32_t> maxLimit = getMax())
+    printer << *maxLimit;
+  printer << ']';
+}
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/global.mlir b/mlir/test/Dialect/WebAssemblySSA/custom_parser/global.mlir
new file mode 100644
index 0000000000000..b9b342052ff1b
--- /dev/null
+++ b/mlir/test/Dialect/WebAssemblySSA/custom_parser/global.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Round-trip test: parses the module below and checks that the custom
+// printers of wasmssa.import_global and wasmssa.global reproduce it,
+// including the optional `mutable` keyword.
+module {
+ wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+
+ wasmssa.global @global_1 i32 : {
+ %0 = wasmssa.const 10 : i32
+ wasmssa.return %0 : i32
+ }
+ wasmssa.global @global_2 i32 mutable : {
+ %0 = wasmssa.const 17 : i32
+ wasmssa.return %0 : i32
+ }
+ wasmssa.global @global_3 i32 mutable : {
+ %0 = wasmssa.const 10 : i32
+ wasmssa.return %0 : i32
+ }
+ // Initializer reads a non-mutable imported global, which is accepted as a
+ // constant expression.
+ wasmssa.global @global_4 i32 : {
+ %0 = wasmssa.global_get @global_0 : i32
+ wasmssa.return %0 : i32
+ }
+}
+
+// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+
+// CHECK-LABEL: wasmssa.global @global_1 i32 : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.global @global_2 i32 mutable : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.global @global_3 i32 mutable : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.global @global_4 i32 : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.global_get @global_0 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+// CHECK: }
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/global_illegal.mlir b/mlir/test/Dialect/WebAssemblySSA/custom_parser/global_illegal.mlir
new file mode 100644
index 0000000000000..c824593e8cdf5
--- /dev/null
+++ b/mlir/test/Dialect/WebAssemblySSA/custom_parser/global_illegal.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -verify-diagnostics --split-input-file
+
+// The initializer uses wasmssa.add, which is not a valid constant expression,
+// so the global is rejected.
+module {
+ // expected-error at +1 {{Expected a constant initializer for this operator}}
+ wasmssa.global @illegal i32 mutable : {
+ %0 = wasmssa.const 17: i32
+ %1 = wasmssa.const 35: i32
+ %2 = wasmssa.add %0 %1 : i32
+ wasmssa.return %2 : i32
+ }
+}
+
+// -----
+
+// A global_get of a *mutable* imported global is not a constant expression.
+module {
+ wasmssa.import_global "glob" from "my_module" as @global_0 mutable nested : i32
+ // expected-error at +1 {{Expected a constant initializer for this operator}}
+ wasmssa.global @global_1 i32 : {
+ // expected-error at +1 {{global.get op is considered constant if it's referring to a import.global symbol marked non-mutable}}
+ %0 = wasmssa.global_get @global_0 : i32
+ wasmssa.return %0 : i32
+ }
+}
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/import.mlir b/mlir/test/Dialect/WebAssemblySSA/custom_parser/import.mlir
new file mode 100644
index 0000000000000..3cc05486c4a4d
--- /dev/null
+++ b/mlir/test/Dialect/WebAssemblySSA/custom_parser/import.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Round-trip test for the import ops: function, table, memory and global
+// (including a `mutable` global import).
+module {
+ wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
+ wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
+ wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
+ wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
+ wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
+ wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
+}
+
+// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
+// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
+// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
+// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
+// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
+// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/local.mlir b/mlir/test/Dialect/WebAssemblySSA/custom_parser/local.mlir
new file mode 100644
index 0000000000000..3f6423fa2a1f5
--- /dev/null
+++ b/mlir/test/Dialect/WebAssemblySSA/custom_parser/local.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Round-trip test for wasmssa.local declarations and for function arguments
+// of type !wasmssa<local ref to T>.
+module {
+ wasmssa.func nested @func_0() -> f32 {
+ %0 = wasmssa.local of type f32
+ %1 = wasmssa.local of type f32
+ %2 = wasmssa.const 8.000000e+00 : f32
+ %3 = wasmssa.const 1.200000e+01 : f32
+ %4 = wasmssa.add %2 %3 : f32
+ wasmssa.return %4 : f32
+ }
+ wasmssa.func nested @func_1() -> i32 {
+ %0 = wasmssa.local of type i32
+ %1 = wasmssa.local of type i32
+ %2 = wasmssa.const 8 : i32
+ %3 = wasmssa.const 12 : i32
+ %4 = wasmssa.add %2 %3 : i32
+ wasmssa.return %4 : i32
+ }
+ wasmssa.func nested @func_2(%arg0: !wasmssa<local ref to i32>) -> i32 {
+ %0 = wasmssa.const 3 : i32
+ wasmssa.return %0 : i32
+ }
+}
+
+// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32
+// CHECK: %[[VAL_3:.*]] = wasmssa.const 1.200000e+01 : f32
+// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : f32
+// CHECK: wasmssa.return %[[VAL_4]] : f32
+
+// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : i32
+// CHECK: wasmssa.return %[[VAL_4]] : i32
+
+// CHECK-LABEL: wasmssa.func nested @func_2(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
>From f79a7d51a2fe052649df0010224978b515a923f7 Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Tue, 22 Jul 2025 13:27:19 +0900
Subject: [PATCH 02/11] [mlir][wasm] Apply review comments regarding formatting
and documentation
---
.../IR/WebAssemblySSAInterfaces.td | 10 ++---
.../WebAssemblySSA/IR/WebAssemblySSAOps.td | 40 ++++++++++---------
.../WebAssemblySSA/IR/WebAssemblySSATypes.td | 2 -
.../IR/WebAssemblySSAInterfaces.cpp | 9 +++--
.../WebAssemblySSA/IR/WebAssemblySSAOps.cpp | 26 +++++++-----
.../WebAssemblySSA/IR/WebAssemblySSATypes.cpp | 10 ++++-
...lobal_illegal.mlir => global-illegal.mlir} | 4 +-
7 files changed, 59 insertions(+), 42 deletions(-)
rename mlir/test/Dialect/WebAssemblySSA/custom_parser/{global_illegal.mlir => global-illegal.mlir} (82%)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
index 7857556871c3f..1e6bb0391d30c 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
@@ -1,4 +1,4 @@
-//===-- WebAssemblySSAInterfaces.td - WebAssemblySSA Interfaces -*- tablegen -*-===//
+//===-- WebAssemblySSAInterfaces.td - WasmSSA Interfaces -*- tablegen -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -16,7 +16,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
-def WasmSSALabelLevelInterface : OpInterface<"WasmSSALabelLevelInterface"> {
+def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelInterface"> {
let description = [{
Operation that defines one level of nesting for wasm branching.
These operation region can be targeted by branch instructions.
@@ -31,7 +31,7 @@ def WasmSSALabelLevelInterface : OpInterface<"WasmSSALabelLevelInterface"> {
];
}
-def WasmSSALabelBranchingInterface : OpInterface<"WasmSSALabelBranchingInterface"> {
+def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingInterface"> {
let description = [{
Wasm operation that targets a label for a jump.
}];
@@ -104,7 +104,7 @@ def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
/*methodName=*/ "getQualifiedImportName",
/*args=*/ (ins),
/*methodBody=*/ [{
- return ($_op.getModuleName() + llvm::Twine{"::"} + $_op.getImportName()).str();
+ return ($_op.getModuleName() + ::llvm::Twine{"::"} + $_op.getImportName()).str();
}]
>,
];
@@ -183,4 +183,4 @@ def WasmSSAConstantExprInterface :
];
}
-#endif // WEBASSEMBLYSSA_INTERFACES
+#endif // WEBASSEMBLY_INTERFACES
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
index 9e370920f3173..4f39a99a6b052 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
@@ -19,11 +19,15 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
+// Base class for WasmSSA operations.
+// Most operations map 1:1 onto Wasm instructions; the only difference is that
+// Wasm's stack-based operand passing is replaced by SSA values. Operations
+// that match the Wasm spec 1:1 are given no description.
class WasmSSA_Op<string mnemonic, list<Trait> traits = []> :
Op<WasmSSA_Dialect, mnemonic, traits>;
class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> :
- WasmSSA_Op<mnemonic, [Terminator, DeclareOpInterfaceMethods<WasmSSALabelLevelInterface>]> {
+ WasmSSA_Op<mnemonic, [Terminator, DeclareOpInterfaceMethods<WasmSSALabelLevelOpInterface>]> {
let summary = summaryStr;
let arguments = (ins Variadic<WasmSSA_ValType>: $inputs);
let regions = (region AnyRegion: $body);
@@ -44,7 +48,7 @@ def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {}
def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
- DeclareOpInterfaceMethods<WasmSSALabelBranchingInterface>]> {
+ DeclareOpInterfaceMethods<WasmSSALabelBranchingOpInterface>]> {
let summary = "Return from the current block";
let arguments = (ins Variadic<WasmSSA_ValType>: $inputs);
let extraClassDeclaration = [{
@@ -55,7 +59,7 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
def WasmSSA_BranchIfOp : WasmSSA_Op<"branch_if", [
Terminator,
- DeclareOpInterfaceMethods<WasmSSALabelBranchingInterface>]> {
+ DeclareOpInterfaceMethods<WasmSSALabelBranchingOpInterface>]> {
let summary = "Jump to target level if condition has non-zero value";
let arguments = (ins I32: $condition,
UI32Attr: $exitLevel,
@@ -74,7 +78,7 @@ def WasmSSA_ConstOp : WasmSSA_Op<"const", [
def WasmSSA_FuncOp : WasmSSA_Op<"func", [
AffineScope, AutomaticAllocationScope,
- DeclareOpInterfaceMethods<FunctionOpInterface>,
+ DeclareOpInterfaceMethods<FunctionOpInterface, ["verifyBody"]>,
IsolatedFromAbove,
Symbol]> {
let description = [{
@@ -111,8 +115,6 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
/// function.
::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
- ::mlir::LogicalResult verifyBody();
-
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
@@ -121,7 +123,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
}];
let builders = [
- OpBuilder<(ins "llvm::StringRef":$symbol, "FunctionType":$funcType )>
+ OpBuilder<(ins "::llvm::StringRef":$symbol, "FunctionType":$funcType )>
];
let hasCustomAssemblyFormat = 1;
}
@@ -154,11 +156,11 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
Region *getCallableRegion() { return nullptr; }
- llvm::ArrayRef<Type> getArgumentTypes() {
+ ::llvm::ArrayRef<Type> getArgumentTypes() {
return getType().getInputs();
}
- llvm::ArrayRef<Type> getResultTypes() {
+ ::llvm::ArrayRef<Type> getResultTypes() {
return getType().getResults();
}
}];
@@ -224,7 +226,7 @@ def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<Wa
}
def WasmSSA_IfOp : WasmSSA_Op<"if", [Terminator,
- DeclareOpInterfaceMethods<WasmSSALabelLevelInterface>]> {
+ DeclareOpInterfaceMethods<WasmSSALabelLevelOpInterface>]> {
let summary = "Execute the if region if condition value is nonzero, the else region otherwise.";
let arguments = (ins I32:$condition, Variadic<WasmSSA_ValType>: $inputs);
let regions = (region AnyRegion: $if, AnyRegion: $else);
@@ -293,8 +295,8 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
OptionalAttr<StrAttr>:$sym_visibility);
let builders = [
OpBuilder<(ins
- "llvm::StringRef":$symbol,
- "LimitType":$limit)>
+ "::llvm::StringRef":$symbol,
+ "wasmssa::LimitType":$limit)>
];
}
@@ -309,9 +311,9 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, WasmSSAImportOpInter
bool isDeclaration() const { return true; }
}];
let builders = [OpBuilder<(ins
- "llvm::StringRef":$symbol,
- "llvm::StringRef":$moduleName,
- "llvm::StringRef":$importName,
+ "::llvm::StringRef":$symbol,
+ "::llvm::StringRef":$moduleName,
+ "::llvm::StringRef":$importName,
"wasmssa::LimitType":$limits)>];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
}
@@ -322,7 +324,7 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
WasmSSA_TableTypeAttr: $type,
OptionalAttr<StrAttr>:$sym_visibility);
let builders = [OpBuilder<(ins
- "llvm::StringRef":$symbol,
+ "::llvm::StringRef":$symbol,
"wasmssa::TableType":$type)>];
}
@@ -338,9 +340,9 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, WasmSSAImportOpI
}];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
let builders = [OpBuilder<(ins
- "llvm::StringRef":$symbol,
- "llvm::StringRef":$moduleName,
- "llvm::StringRef":$importName,
+ "::llvm::StringRef":$symbol,
+ "::llvm::StringRef":$moduleName,
+ "::llvm::StringRef":$importName,
"wasmssa::TableType":$type)>];
}
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
index 1ab2196d70b34..e5ddcb743b034 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
@@ -80,7 +80,5 @@ def WasmSSA_ValTypeAttr : TypeAttrOf<WasmSSA_ValType>;
def WasmSSA_IntegerAttr : AnyAttrOf<[I32Attr, I64Attr]>;
def WasmSSA_FPAttr : AnyAttrOf<[F32Attr, F64Attr]>;
def WasmSSA_NumericAttr : AnyAttrOf<[WasmSSA_IntegerAttr, WasmSSA_FPAttr]>;
-def WasmSSA_VecAttr : TypedSignlessIntegerAttrBase<
- I128, "::llvm::APInt", "128 bits signless integer attribute">;
#endif // WEBASSEMBLYSSA_TYPES
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
index 0b43f2f671062..2dc75de234857 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
@@ -22,14 +22,15 @@ namespace mlir::wasmssa {
namespace detail {
LogicalResult verifyWasmSSALabelBranchingInterface(Operation *op) {
auto branchInterface = dyn_cast<WasmSSALabelBranchingInterface>(op);
- auto res = WasmSSALabelBranchingInterface::getTargetOpFromBlock(
- op->getBlock(), branchInterface.getExitLevel());
+ llvm::FailureOr<WasmSSALabelLevelInterface> res =
+ WasmSSALabelBranchingInterface::getTargetOpFromBlock(
+ op->getBlock(), branchInterface.getExitLevel());
return success(succeeded(res));
}
LogicalResult verifyConstantExpressionInterface(Operation *op) {
Region &initializerRegion = op->getRegion(0);
- auto resultState =
+ WalkResult resultState =
initializerRegion.walk([&](Operation *currentOp) -> WalkResult {
if (isa<ReturnOp>(currentOp))
return WalkResult::advance();
@@ -38,7 +39,7 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
if (interfaceOp.isValidInConstantExpr().succeeded())
return WalkResult::advance();
}
- op->emitError("Expected a constant initializer for this operator, got ")
+ op->emitError("expected a constant initializer for this operator, got ")
<< currentOp;
return WalkResult::interrupt();
});
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
index 00bb83caa9ee6..186632c4e3ddc 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
@@ -1,3 +1,11 @@
+//===- WebAssemblySSAOps.cpp - WasmSSA dialect operations ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
#include "mlir/IR/Attributes.h"
@@ -92,7 +100,7 @@ ParseResult ExtendLowBitsSOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
OpAsmParser::UnresolvedOperand operand;
uint64_t nBits;
- auto parseRes = parser.parseInteger(nBits);
+ ParseResult parseRes = parser.parseInteger(nBits);
parseRes = parser.parseKeyword("low");
parseRes = parser.parseKeyword("bits");
parseRes = parser.parseKeyword("from");
@@ -124,11 +132,11 @@ void ExtendLowBitsSOp::print(OpAsmPrinter &p) {
LogicalResult ExtendLowBitsSOp::verify() {
auto bitsToTake = getBitsToTake().getValue().getLimitedValue();
if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8)
- return emitError("Extend op can only take 8, 16 or 32 bits. Got ")
+ return emitError("extend op can only take 8, 16 or 32 bits. Got ")
<< bitsToTake;
if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth())
- return emitError("Trying to extend the ")
+ return emitError("trying to extend the ")
<< bitsToTake << " low bits from a " << getInput().getType()
<< " value";
return success();
@@ -140,7 +148,7 @@ LogicalResult ExtendLowBitsSOp::verify() {
Block *FuncOp::addEntryBlock() {
if (!getBody().empty()) {
- emitError("Adding entry block to a FuncOp which already has one.");
+ emitError("adding entry block to a FuncOp which already has one.");
return &getBody().front();
}
Block &block = getBody().emplaceBlock();
@@ -170,7 +178,7 @@ ParseResult FuncOp::parse(::mlir::OpAsmParser &parser,
auto refType = dyn_cast<LocalRefType>(argType);
auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (!refType) {
- mlir::emitError(loc, "Invalid type for wasm.func argument. Expecting "
+ mlir::emitError(loc, "invalid type for wasm.func argument. Expecting "
"!wasm<local T>, got ")
<< argType << ".";
return;
@@ -192,7 +200,7 @@ LogicalResult FuncOp::verifyBody() {
return success();
Block &entry = getBody().front();
if (entry.getNumArguments() != getFunctionType().getNumInputs())
- return emitError("Entry block should have same number of arguments as "
+ return emitError("entry block should have same number of arguments as "
"function type. Function type has ")
<< getFunctionType().getNumInputs() << ", entry block has "
<< entry.getNumArguments() << ".";
@@ -201,10 +209,10 @@ LogicalResult FuncOp::verifyBody() {
getFunctionType().getInputs(), entry.getArgumentTypes())) {
auto blockLocalRefType = dyn_cast<LocalRefType>(blockType);
if (!blockLocalRefType)
- return emitError("Entry block argument type should be LocalRefType, got ")
+ return emitError("entry block argument type should be LocalRefType, got ")
<< blockType << " for block argument " << argNo << ".";
if (blockLocalRefType.getElementType() != funcSignatureType)
- return emitError("Func argument type #")
+ return emitError("func argument type #")
<< argNo << "(" << funcSignatureType
<< ") doesn't match entry block referenced type ("
<< blockLocalRefType.getElementType() << ").";
@@ -253,7 +261,7 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
StringAttr symbolName;
Type globalType;
auto *ctx = parser.getContext();
- auto res = parser.parseSymbolName(
+ ParseResult res = parser.parseSymbolName(
symbolName, SymbolTable::getSymbolAttrName(), result.attributes);
res = parser.parseType(globalType);
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
index 27b3af5af0a6f..f5f5d80c09ab9 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
@@ -1,3 +1,11 @@
+//===- WebAssemblySSATypes.cpp - WasmSSA dialect types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
@@ -29,7 +37,7 @@ Type LimitType::parse(::mlir::AsmParser &parser) {
void LimitType::print(AsmPrinter &printer) const {
printer << '[' << getMin() << ':';
- auto maxLim = getMax();
+ std::optional<uint32_t> maxLim = getMax();
if (maxLim)
printer << *maxLim;
printer << ']';
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/global_illegal.mlir b/mlir/test/Dialect/WebAssemblySSA/custom_parser/global-illegal.mlir
similarity index 82%
rename from mlir/test/Dialect/WebAssemblySSA/custom_parser/global_illegal.mlir
rename to mlir/test/Dialect/WebAssemblySSA/custom_parser/global-illegal.mlir
index c824593e8cdf5..3571565564b6d 100644
--- a/mlir/test/Dialect/WebAssemblySSA/custom_parser/global_illegal.mlir
+++ b/mlir/test/Dialect/WebAssemblySSA/custom_parser/global-illegal.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -verify-diagnostics --split-input-file
module {
- // expected-error at +1 {{Expected a constant initializer for this operator}}
+ // expected-error at +1 {{expected a constant initializer for this operator}}
wasmssa.global @illegal i32 mutable : {
%0 = wasmssa.const 17: i32
%1 = wasmssa.const 35: i32
@@ -14,7 +14,7 @@ module {
module {
wasmssa.import_global "glob" from "my_module" as @global_0 mutable nested : i32
- // expected-error at +1 {{Expected a constant initializer for this operator}}
+ // expected-error at +1 {{expected a constant initializer for this operator}}
wasmssa.global @global_1 i32 : {
// expected-error at +1 {{global.get op is considered constant if it's referring to a import.global symbol marked non-mutable}}
%0 = wasmssa.global_get @global_0 : i32
>From ba474ce228f6b86000864c83abc9d06d053f8f67 Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Wed, 23 Jul 2025 11:29:13 +0900
Subject: [PATCH 03/11] [mlir][wasm] Apply review comments on interface
replacement by trait
---
.../IR/WebAssemblySSAInterfaces.h | 8 ++-
.../IR/WebAssemblySSAInterfaces.td | 56 +++++--------------
.../WebAssemblySSA/IR/WebAssemblySSAOps.td | 6 +-
.../IR/WebAssemblySSAInterfaces.cpp | 2 +-
.../WebAssemblySSA/IR/WebAssemblySSAOps.cpp | 3 +-
5 files changed, 26 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
index 03c4021b1421b..769ea9ecafce0 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
@@ -17,12 +17,16 @@
#include "mlir/IR/OpDefinition.h"
namespace mlir::wasmssa {
+
+template <class OperationType>
+struct AlwaysValidConstantExprTrait
+ : public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprTrait> {};
+
namespace detail {
LogicalResult verifyConstantExpressionInterface(Operation *op);
LogicalResult verifyWasmSSALabelBranchingInterface(Operation *op);
} // namespace detail
-
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h.inc"
} // namespace mlir::wasmssa
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h.inc"
#endif // MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSAINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
index 1e6bb0391d30c..d682c035ff4ec 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
@@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelInterface"> {
+ let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that defines one level of nesting for wasm branching.
These operation region can be targeted by branch instructions.
@@ -32,6 +33,7 @@ def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelInterface"> {
}
def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingInterface"> {
+ let cppNamespace = "::mlir::wasmssa";
let description = [{
Wasm operation that targets a label for a jump.
}];
@@ -70,6 +72,7 @@ def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingInterfa
}
def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
+ let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that imports a symbol from an external wasm module;
}];
@@ -112,6 +115,7 @@ def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
def WasmSSAConstantExpressionInitializerInterface :
OpInterface<"WasmSSAConstantExpressionInitializerInterface"> {
+ let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that must be constant initialized. This
interface adds a verifier that checks that all ops
@@ -122,11 +126,11 @@ def WasmSSAConstantExpressionInitializerInterface :
let verify = [{ return detail::verifyConstantExpressionInterface($_op); }];
}
-def WasmSSAConstantExprCheckInterface :
- OpInterface<"WasmSSAConstantExprCheckInterface"> {
+def ConstantExprCheckInterface :
+ OpInterface<"ConstantExprCheckInterface"> {
+ let cppNamespace = "::mlir::wasmssa";
let description = [{
Base interface for operations that can be used in a Wasm Constant Expression.
- It shouldn't be used directly, use one of the derived instead.
}];
let methods = [
@@ -136,51 +140,17 @@ def WasmSSAConstantExprCheckInterface :
}],
/*returnType=*/ "::mlir::LogicalResult",
/*methodName=*/ "isValidInConstantExpr",
- /*args=*/ (ins),
- /*methodBody=*/ [{
- return $_op.verifyConstantExprValidity();
- }]
- >
- ];
-}
-
-def WasmSSAContextuallyConstantExprInterface :
- OpInterface<"WasmSSAContextuallyConstantExprInterface", [WasmSSAConstantExprCheckInterface]> {
- let description = [{
- Base interface for operations that can be used in a Wasm Constant Expression
- depending on the context.
- }];
-
- let methods = [
- InterfaceMethod<
- /*desc=*/ [{
- Returns success if the current operation is valid in a constant expression context.
- }],
- /*returnType=*/ "::mlir::LogicalResult",
- /*methodName=*/ "verifyConstantExprValidity",
/*args=*/ (ins)
>
];
}
-def WasmSSAConstantExprInterface :
- OpInterface<"WasmSSAConstantExprInterface", [WasmSSAConstantExprCheckInterface]> {
- let description = [{
- Base interface for operations that can always be used in a Wasm Constant Expression.
- }];
-
- let methods = [
- InterfaceMethod<
- /*desc=*/ [{
- Returns success if the current operation is valid in a constant expression context.
- }],
- /*returnType=*/ "::mlir::LogicalResult",
- /*methodName=*/ "verifyConstantExprValidity",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*DefaultImplementation=*/ [{ return success(); }]
- >
- ];
+def AlwaysValidConstantExprTrait : NativeOpTrait<"AlwaysValidConstantExprTrait", [], [{
+ ::mlir::LogicalResult isValidInConstantExpr() {
+ return success();
+ }
+ }]> {
+ let cppNamespace = "::mlir::wasmssa";
}
#endif // WEBASSEMBLY_INTERFACES
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
index 4f39a99a6b052..8ce78853397f4 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
@@ -69,7 +69,9 @@ def WasmSSA_BranchIfOp : WasmSSA_Op<"branch_if", [
}
def WasmSSA_ConstOp : WasmSSA_Op<"const", [
- AllTypesMatch<["value", "result"]>, WasmSSAConstantExprInterface]> {
+ AllTypesMatch<["value", "result"]>,
+ ConstantExprCheckInterface,
+ AlwaysValidConstantExprTrait]> {
let summary = "Operator that represents a constant value";
let arguments = (ins TypedAttrInterface: $value);
let results = (outs WasmSSA_NumericType: $result);
@@ -218,7 +220,7 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
let hasCustomAssemblyFormat = 1;
}
-def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<WasmSSAContextuallyConstantExprInterface>]> {
+def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<ConstantExprCheckInterface>]> {
let summary = "Returns the value of the global passed as argument.";
let arguments = (ins FlatSymbolRefAttr: $global);
let results = (outs WasmSSA_ValType: $global_val);
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
index 2dc75de234857..e6c0957dd449c 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
@@ -35,7 +35,7 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
if (isa<ReturnOp>(currentOp))
return WalkResult::advance();
if (auto interfaceOp =
- dyn_cast<WasmSSAConstantExprCheckInterface>(currentOp)) {
+ dyn_cast<ConstantExprCheckInterface>(currentOp)) {
if (interfaceOp.isValidInConstantExpr().succeeded())
return WalkResult::advance();
}
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
index 186632c4e3ddc..bb11b4efab78a 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
+++ b/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
@@ -6,6 +6,7 @@
//
//===---------------------------------------------------------------------===//
+#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h"
#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
#include "mlir/IR/Attributes.h"
@@ -301,7 +302,7 @@ void GlobalOp::print(OpAsmPrinter &printer) {
//===----------------------------------------------------------------------===//
// Custom interface overrides
-LogicalResult GlobalGetOp::verifyConstantExprValidity() {
+LogicalResult GlobalGetOp::isValidInConstantExpr() {
StringRef referencedSymbol = getGlobal();
Operation *symTableOp =
getOperation()->getParentWithTrait<OpTrait::SymbolTable>();
>From e56c6316d546bc7766895544f87a52c54227aebe Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Wed, 23 Jul 2025 12:41:46 +0900
Subject: [PATCH 04/11] [mlir][wasm] Change
 WasmSSAConstantExpressionInitializerInterface to a trait
---
.../IR/WebAssemblySSAInterfaces.h | 17 ++++++++++++-----
.../IR/WebAssemblySSAInterfaces.td | 11 +----------
.../WebAssemblySSA/IR/WebAssemblySSAOps.td | 2 +-
3 files changed, 14 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
index 769ea9ecafce0..f50ecd9e3910b 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
@@ -17,15 +17,22 @@
#include "mlir/IR/OpDefinition.h"
namespace mlir::wasmssa {
-
-template <class OperationType>
-struct AlwaysValidConstantExprTrait
- : public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprTrait> {};
-
namespace detail {
LogicalResult verifyConstantExpressionInterface(Operation *op);
LogicalResult verifyWasmSSALabelBranchingInterface(Operation *op);
} // namespace detail
+template <class OperationType>
+struct AlwaysValidConstantExprTrait
+ : public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprTrait> {};
+
+
+template<typename OpType>
+struct ConstantExpressionInitializerTrait : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerTrait>{
+ static LogicalResult verifyTrait(Operation* op) {
+ return detail::verifyConstantExpressionInterface(op);
+ }
+};
+
} // namespace mlir::wasmssa
#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
index d682c035ff4ec..6a7e451d94568 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
@@ -113,17 +113,8 @@ def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
];
}
-def WasmSSAConstantExpressionInitializerInterface :
- OpInterface<"WasmSSAConstantExpressionInitializerInterface"> {
+def ConstantExpressionInitializerTrait : NativeOpTrait<"ConstantExpressionInitializerTrait"> {
let cppNamespace = "::mlir::wasmssa";
- let description = [{
- Operation that must be constant initialized. This
- interface adds a verifier that checks that all ops
- within the initializer region are "constant expressions"
- as defined by the WASM standard.
- }];
-
- let verify = [{ return detail::verifyConstantExpressionInterface($_op); }];
}
def ConstantExprCheckInterface :
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
index 8ce78853397f4..aeb82e4932611 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
+++ b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
@@ -177,7 +177,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
AffineScope, AutomaticAllocationScope,
- IsolatedFromAbove, Symbol, WasmSSAConstantExpressionInitializerInterface]> {
+ IsolatedFromAbove, Symbol, ConstantExpressionInitializerTrait]> {
let summary= "WebAssembly global value";
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_ValTypeAttr: $type,
>From 50efecceed5d3853b1a09ac75a13e0131b1918f1 Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Wed, 23 Jul 2025 13:01:42 +0900
Subject: [PATCH 05/11] Rename everything to WasmSSA for consistency
---
mlir/include/mlir/Dialect/CMakeLists.txt | 2 +-
.../CMakeLists.txt | 0
.../mlir/Dialect/WasmSSA/IR/CMakeLists.txt | 13 ++++++++
.../WebAssemblySSA.h => WasmSSA/IR/WasmSSA.h} | 18 +++++------
.../IR/WasmSSABase.td} | 8 ++---
.../IR/WasmSSAInterfaces.h} | 20 ++++++-------
.../IR/WasmSSAInterfaces.td} | 30 +++++++++----------
.../IR/WasmSSAOps.td} | 22 +++++++-------
.../IR/WasmSSATypes.td} | 10 +++----
.../Dialect/WebAssemblySSA/IR/CMakeLists.txt | 13 --------
mlir/include/mlir/InitAllDialects.h | 2 +-
mlir/lib/Dialect/CMakeLists.txt | 2 +-
.../CMakeLists.txt | 0
.../IR/CMakeLists.txt | 14 ++++-----
.../IR/WasmSSADialect.cpp} | 10 +++----
.../IR/WasmSSAInterfaces.cpp} | 28 ++++++++---------
.../IR/WasmSSAOps.cpp} | 10 +++----
.../IR/WasmSSATypes.cpp} | 6 ++--
.../custom_parser/global-illegal.mlir | 0
.../custom_parser/global.mlir | 0
.../custom_parser/import.mlir | 0
.../custom_parser/local.mlir | 0
22 files changed, 104 insertions(+), 104 deletions(-)
rename mlir/include/mlir/Dialect/{WebAssemblySSA => WasmSSA}/CMakeLists.txt (100%)
create mode 100644 mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt
rename mlir/include/mlir/Dialect/{WebAssemblySSA/IR/WebAssemblySSA.h => WasmSSA/IR/WasmSSA.h} (72%)
rename mlir/include/mlir/Dialect/{WebAssemblySSA/IR/WebAssemblySSABase.td => WasmSSA/IR/WasmSSABase.td} (79%)
rename mlir/include/mlir/Dialect/{WebAssemblySSA/IR/WebAssemblySSAInterfaces.h => WasmSSA/IR/WasmSSAInterfaces.h} (56%)
rename mlir/include/mlir/Dialect/{WebAssemblySSA/IR/WebAssemblySSAInterfaces.td => WasmSSA/IR/WasmSSAInterfaces.td} (80%)
rename mlir/include/mlir/Dialect/{WebAssemblySSA/IR/WebAssemblySSAOps.td => WasmSSA/IR/WasmSSAOps.td} (97%)
rename mlir/include/mlir/Dialect/{WebAssemblySSA/IR/WebAssemblySSATypes.td => WasmSSA/IR/WasmSSATypes.td} (92%)
delete mode 100644 mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt
rename mlir/lib/Dialect/{WebAssemblySSA => WasmSSA}/CMakeLists.txt (100%)
rename mlir/lib/Dialect/{WebAssemblySSA => WasmSSA}/IR/CMakeLists.txt (52%)
rename mlir/lib/Dialect/{WebAssemblySSA/IR/WebAssemblySSADialect.cpp => WasmSSA/IR/WasmSSADialect.cpp} (72%)
rename mlir/lib/Dialect/{WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp => WasmSSA/IR/WasmSSAInterfaces.cpp} (62%)
rename mlir/lib/Dialect/{WebAssemblySSA/IR/WebAssemblySSAOps.cpp => WasmSSA/IR/WasmSSAOps.cpp} (98%)
rename mlir/lib/Dialect/{WebAssemblySSA/IR/WebAssemblySSATypes.cpp => WasmSSA/IR/WasmSSATypes.cpp} (84%)
rename mlir/test/Dialect/{WebAssemblySSA => WasmSSA}/custom_parser/global-illegal.mlir (100%)
rename mlir/test/Dialect/{WebAssemblySSA => WasmSSA}/custom_parser/global.mlir (100%)
rename mlir/test/Dialect/{WebAssemblySSA => WasmSSA}/custom_parser/import.mlir (100%)
rename mlir/test/Dialect/{WebAssemblySSA => WasmSSA}/custom_parser/local.mlir (100%)
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index eb6075ac76c85..9e45214c784fc 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -41,6 +41,6 @@ add_subdirectory(Transform)
add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
-add_subdirectory(WebAssemblySSA)
+add_subdirectory(WasmSSA)
add_subdirectory(X86Vector)
add_subdirectory(XeGPU)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/CMakeLists.txt b/mlir/include/mlir/Dialect/WasmSSA/CMakeLists.txt
similarity index 100%
rename from mlir/include/mlir/Dialect/WebAssemblySSA/CMakeLists.txt
rename to mlir/include/mlir/Dialect/WasmSSA/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..28823947230f6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+set(LLVM_TARGET_DEFINITIONS WasmSSATypes.td)
+mlir_tablegen(WasmSSATypeConstraints.h.inc -gen-type-constraint-decls)
+mlir_tablegen(WasmSSATypeConstraints.cpp.inc -gen-type-constraint-defs)
+
+set (LLVM_TARGET_DEFINITIONS WasmSSAInterfaces.td)
+mlir_tablegen(WasmSSAInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(WasmSSAInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRWasmSSAInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS WasmSSAOps.td)
+
+add_mlir_dialect(WasmSSAOps wasmssa)
+add_mlir_doc(WasmSSAOps WasmSSAOps Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSA.h
similarity index 72%
rename from mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h
rename to mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSA.h
index 816f7ef008d4a..64391d807c633 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSA.h
@@ -1,4 +1,4 @@
-//===- WebAssemblySSA.h - WebAssemblySSA dialect ------------------*- C++-*-==//
+//===- WasmSSA.h - WasmSSA dialect ------------------*- C++-*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSA_H_
-#define MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSA_H_
+#ifndef MLIR_DIALECT_WasmSSA_IR_WasmSSA_H_
+#define MLIR_DIALECT_WasmSSA_IR_WasmSSA_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
@@ -16,20 +16,20 @@
// WebAssemblyDialect
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsDialect.h.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// WebAssembly Dialect Types
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsTypes.h.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.h.inc"
//===----------------------------------------------------------------------===//
// WebAssembly Interfaces
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h"
//===----------------------------------------------------------------------===//
// WebAssembly Dialect Operations
@@ -45,11 +45,11 @@
namespace mlir {
namespace wasmssa {
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypeConstraints.h.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSATypeConstraints.h.inc"
}
} // namespace mlir
#define GET_OP_CLASSES
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.h.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.h.inc"
-#endif // MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSA_H_
+#endif // MLIR_DIALECT_WasmSSA_IR_WasmSSA_H_
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSABase.td
similarity index 79%
rename from mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td
rename to mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSABase.td
index cdc4d4864344f..f2777a7b155ed 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSABase.td
@@ -1,4 +1,4 @@
-//===- WebAssemblySSABase.td - Base defs for wasmssa dialect -*- tablegen -*-==//
+//===- WasmSSABase.td - Base defs for wasmssa dialect -*- tablegen -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef WEBASSEMBLYSSA_BASE
-#define WEBASSEMBLYSSA_BASE
+#ifndef WasmSSA_BASE
+#define WasmSSA_BASE
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
@@ -22,4 +22,4 @@ def WasmSSA_Dialect : Dialect {
let useDefaultTypePrinterParser = true;
}
-#endif //WEBASSEMBLYSSA_BASE
+#endif //WasmSSA_BASE
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
similarity index 56%
rename from mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
rename to mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
index f50ecd9e3910b..b0986635605a5 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
@@ -1,4 +1,4 @@
-//===- WebAssemblySSAInterfaces.h - WebAssemblySSA Interfaces ---*- C++ -*-===//
+//===- WasmSSAInterfaces.h - WasmSSA Interfaces ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines op interfaces for the WebAssemblySSA dialect in MLIR.
+// This file defines op interfaces for the WasmSSA dialect in MLIR.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSAINTERFACES_H_
-#define MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSAINTERFACES_H_
+#ifndef MLIR_DIALECT_WasmSSA_IR_WasmSSAINTERFACES_H_
+#define MLIR_DIALECT_WasmSSA_IR_WasmSSAINTERFACES_H_
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
@@ -19,21 +19,21 @@
namespace mlir::wasmssa {
namespace detail {
LogicalResult verifyConstantExpressionInterface(Operation *op);
-LogicalResult verifyWasmSSALabelBranchingInterface(Operation *op);
+LogicalResult verifyWasmSSALabelBranchingOpInterface(Operation *op);
} // namespace detail
template <class OperationType>
-struct AlwaysValidConstantExprTrait
- : public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprTrait> {};
+struct AlwaysValidConstantExprOpTrait
+ : public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprOpTrait> {};
template<typename OpType>
-struct ConstantExpressionInitializerTrait : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerTrait>{
+struct ConstantExpressionInitializerOpTrait : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerOpTrait>{
static LogicalResult verifyTrait(Operation* op) {
return detail::verifyConstantExpressionInterface(op);
}
};
} // namespace mlir::wasmssa
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h.inc"
-#endif // MLIR_DIALECT_WEBASSEMBLYSSA_IR_WEBASSEMBLYSSAINTERFACES_H_
+#endif // MLIR_DIALECT_WasmSSA_IR_WasmSSAINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
similarity index 80%
rename from mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
rename to mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
index 6a7e451d94568..6e1239596f1d2 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
@@ -1,4 +1,4 @@
-//===-- WebAssemblySSAInterfaces.td - WasmSSA Interfaces -*- tablegen -*--===//
+//===-- WasmSSAInterfaces.td - WasmSSA Interfaces -*- tablegen -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,17 +6,17 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines interfaces for the WebAssemblySSA dialect in MLIR.
+// This file defines interfaces for the WasmSSA dialect in MLIR.
//
//===----------------------------------------------------------------------===//
-#ifndef WEBASSEMBLYSSA_INTERFACES
-#define WEBASSEMBLYSSA_INTERFACES
+#ifndef WasmSSA_INTERFACES
+#define WasmSSA_INTERFACES
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
-def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelInterface"> {
+def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that defines one level of nesting for wasm branching.
@@ -32,7 +32,7 @@ def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelInterface"> {
];
}
-def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingInterface"> {
+def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Wasm operation that targets a label for a jump.
@@ -46,11 +46,11 @@ def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingInterfa
>,
InterfaceMethod<
/*desc=*/ "Returns the destination of this operation",
- /*returnType=*/ "WasmSSALabelLevelInterface",
+ /*returnType=*/ "WasmSSALabelLevelOpInterface",
/*methodName=*/ "getTargetOp",
/*args=*/ (ins),
/*methodBody=*/ [{
- return *WasmSSALabelBranchingInterface::getTargetOpFromBlock($_op.getOperation()->getBlock(), $_op.getExitLevel());
+ return *WasmSSALabelBranchingOpInterface::getTargetOpFromBlock($_op.getOperation()->getBlock(), $_op.getExitLevel());
}]
>,
InterfaceMethod<
@@ -60,15 +60,15 @@ def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingInterfa
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- auto op = mlir::cast<WasmSSALabelBranchingInterface>(this->getOperation());
+ auto op = mlir::cast<WasmSSALabelBranchingOpInterface>(this->getOperation());
return op.getTargetOp().getLabelTarget();
}]
>
];
let extraClassDeclaration = [{
- static ::llvm::FailureOr<WasmSSALabelLevelInterface> getTargetOpFromBlock(::mlir::Block *block, uint32_t level);
+ static ::llvm::FailureOr<WasmSSALabelLevelOpInterface> getTargetOpFromBlock(::mlir::Block *block, uint32_t level);
}];
- let verify = [{return verifyWasmSSALabelBranchingInterface($_op);}];
+ let verify = [{return verifyWasmSSALabelBranchingOpInterface($_op);}];
}
def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
@@ -113,12 +113,12 @@ def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
];
}
-def ConstantExpressionInitializerTrait : NativeOpTrait<"ConstantExpressionInitializerTrait"> {
+def ConstantExpressionInitializerOpTrait : NativeOpTrait<"ConstantExpressionInitializerOpTrait"> {
let cppNamespace = "::mlir::wasmssa";
}
-def ConstantExprCheckInterface :
- OpInterface<"ConstantExprCheckInterface"> {
+def ConstantExprCheckOpInterface :
+ OpInterface<"ConstantExprCheckOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Base interface for operations that can be used in a Wasm Constant Expression.
@@ -136,7 +136,7 @@ def ConstantExprCheckInterface :
];
}
-def AlwaysValidConstantExprTrait : NativeOpTrait<"AlwaysValidConstantExprTrait", [], [{
+def AlwaysValidInConstantExprOpTrait : NativeOpTrait<"AlwaysValidConstantExprOpTrait", [], [{
::mlir::LogicalResult isValidInConstantExpr() {
return success();
}
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
similarity index 97%
rename from mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
rename to mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index aeb82e4932611..3c9d65ac57cb5 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -1,4 +1,4 @@
-//===- WebAssemblySSAOps.td - WebAssemblySSA op definitions -*- tablegen -*-===//
+//===- WasmSSAOps.td - WasmSSA op definitions -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,13 @@
//
//===----------------------------------------------------------------------===//
-#ifndef WEBASSEMBLYSSA_OPS
-#define WEBASSEMBLYSSA_OPS
+#ifndef WasmSSA_OPS
+#define WasmSSA_OPS
-include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td"
-include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td"
-include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.td"
+include "mlir/Dialect/WasmSSA/IR/WasmSSABase.td"
+include "mlir/Dialect/WasmSSA/IR/WasmSSATypes.td"
+include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -70,8 +70,8 @@ def WasmSSA_BranchIfOp : WasmSSA_Op<"branch_if", [
def WasmSSA_ConstOp : WasmSSA_Op<"const", [
AllTypesMatch<["value", "result"]>,
- ConstantExprCheckInterface,
- AlwaysValidConstantExprTrait]> {
+ ConstantExprCheckOpInterface,
+ AlwaysValidInConstantExprOpTrait]> {
let summary = "Operator that represents a constant value";
let arguments = (ins TypedAttrInterface: $value);
let results = (outs WasmSSA_NumericType: $result);
@@ -177,7 +177,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
AffineScope, AutomaticAllocationScope,
- IsolatedFromAbove, Symbol, ConstantExpressionInitializerTrait]> {
+ IsolatedFromAbove, Symbol, ConstantExpressionInitializerOpTrait]> {
let summary= "WebAssembly global value";
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_ValTypeAttr: $type,
@@ -220,7 +220,7 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
let hasCustomAssemblyFormat = 1;
}
-def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<ConstantExprCheckInterface>]> {
+def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<ConstantExprCheckOpInterface>]> {
let summary = "Returns the value of the global passed as argument.";
let arguments = (ins FlatSymbolRefAttr: $global);
let results = (outs WasmSSA_ValType: $global_val);
@@ -675,4 +675,4 @@ def WasmSSA_PopCntOp : WasmSSA_UnaryNumericalOp<"popcnt",
[WasmSSA_IntegerType]>{}
-#endif // WEBASSEMBLYSSA_OPS
+#endif // WasmSSA_OPS
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
similarity index 92%
rename from mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
rename to mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
index e5ddcb743b034..946362e609fed 100644
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
@@ -1,4 +1,4 @@
-//===- WebAssemblySSATypes.td - WebAssemblySSA types def ----*- tablegen -*-===//
+//===- WasmSSATypes.td - WasmSSA types def ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef WEBASSEMBLYSSA_TYPES
-#define WEBASSEMBLYSSA_TYPES
+#ifndef WasmSSA_TYPES
+#define WasmSSA_TYPES
-include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSABase.td"
+include "mlir/Dialect/WasmSSA/IR/WasmSSABase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"
@@ -81,4 +81,4 @@ def WasmSSA_IntegerAttr : AnyAttrOf<[I32Attr, I64Attr]>;
def WasmSSA_FPAttr : AnyAttrOf<[F32Attr, F64Attr]>;
def WasmSSA_NumericAttr : AnyAttrOf<[WasmSSA_IntegerAttr, WasmSSA_FPAttr]>;
-#endif // WEBASSEMBLYSSA_TYPES
+#endif // WasmSSA_TYPES
diff --git a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt
deleted file mode 100644
index fa41a0caabf91..0000000000000
--- a/mlir/include/mlir/Dialect/WebAssemblySSA/IR/CMakeLists.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS WebAssemblySSATypes.td)
-mlir_tablegen(WebAssemblySSATypeConstraints.h.inc -gen-type-constraint-decls)
-mlir_tablegen(WebAssemblySSATypeConstraints.cpp.inc -gen-type-constraint-defs)
-
-set (LLVM_TARGET_DEFINITIONS WebAssemblySSAInterfaces.td)
-mlir_tablegen(WebAssemblySSAInterfaces.h.inc -gen-op-interface-decls)
-mlir_tablegen(WebAssemblySSAInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRWebAssemblySSAInterfacesIncGen)
-
-set(LLVM_TARGET_DEFINITIONS WebAssemblySSAOps.td)
-
-add_mlir_dialect(WebAssemblySSAOps wasmssa)
-add_mlir_doc(WebAssemblySSAOps WebAssemblySSAOps Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 5f3676a25d561..a33458440c457 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -96,7 +96,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b0403783d0752..9f15156236907 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -41,7 +41,7 @@ add_subdirectory(Transform)
add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
-add_subdirectory(WebAssemblySSA)
+add_subdirectory(WasmSSA)
add_subdirectory(X86Vector)
add_subdirectory(XeGPU)
diff --git a/mlir/lib/Dialect/WebAssemblySSA/CMakeLists.txt b/mlir/lib/Dialect/WasmSSA/CMakeLists.txt
similarity index 100%
rename from mlir/lib/Dialect/WebAssemblySSA/CMakeLists.txt
rename to mlir/lib/Dialect/WasmSSA/CMakeLists.txt
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/CMakeLists.txt b/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt
similarity index 52%
rename from mlir/lib/Dialect/WebAssemblySSA/IR/CMakeLists.txt
rename to mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt
index b106b8b7c2264..9fc2d7b87abd7 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt
@@ -1,15 +1,15 @@
add_mlir_dialect_library(MLIRWasmSSADialect
- WebAssemblySSAOps.cpp
- WebAssemblySSADialect.cpp
- WebAssemblySSAInterfaces.cpp
- WebAssemblySSATypes.cpp
+ WasmSSAOps.cpp
+ WasmSSADialect.cpp
+ WasmSSAInterfaces.cpp
+ WasmSSATypes.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/WebAssemblySSA
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/WasmSSA
DEPENDS
- MLIRWebAssemblySSAOpsIncGen
- MLIRWebAssemblySSAInterfacesIncGen
+ MLIRWasmSSAOpsIncGen
+ MLIRWasmSSAInterfacesIncGen
LINK_LIBS PUBLIC
MLIRCastInterfaces
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSADialect.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp
similarity index 72%
rename from mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSADialect.cpp
rename to mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp
index a37e77256d970..98c3555b5324e 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSADialect.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -17,22 +17,22 @@
using namespace mlir;
using namespace mlir::wasmssa;
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsDialect.cpp.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// TableGen'd types definitions
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsTypes.cpp.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.cpp.inc"
void wasmssa::WasmSSADialect::initialize() {
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOpsTypes.cpp.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.cpp.inc"
>();
}
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
similarity index 62%
rename from mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
rename to mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
index e6c0957dd449c..f7e556353fda2 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
@@ -1,4 +1,4 @@
-//===- WebAssemblySSAInterfaces.cpp - WebAssemblySSA Interfaces -*- C++ -*-===//
+//===- WasmSSAInterfaces.cpp - WasmSSA Interfaces -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,24 +6,24 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines op interfaces for the WebAssemblySSA dialect in MLIR.
+// This file defines op interfaces for the WasmSSA dialect in MLIR.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h"
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
namespace mlir::wasmssa {
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.cpp.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp.inc"
namespace detail {
-LogicalResult verifyWasmSSALabelBranchingInterface(Operation *op) {
- auto branchInterface = dyn_cast<WasmSSALabelBranchingInterface>(op);
- llvm::FailureOr<WasmSSALabelLevelInterface> res =
- WasmSSALabelBranchingInterface::getTargetOpFromBlock(
+LogicalResult verifyWasmSSALabelBranchingOpInterface(Operation *op) {
+ auto branchInterface = dyn_cast<WasmSSALabelBranchingOpInterface>(op);
+ llvm::FailureOr<WasmSSALabelLevelOpInterface> res =
+ WasmSSALabelBranchingOpInterface::getTargetOpFromBlock(
op->getBlock(), branchInterface.getExitLevel());
return success(succeeded(res));
}
@@ -35,7 +35,7 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
if (isa<ReturnOp>(currentOp))
return WalkResult::advance();
if (auto interfaceOp =
- dyn_cast<ConstantExprCheckInterface>(currentOp)) {
+ dyn_cast<ConstantExprCheckOpInterface>(currentOp)) {
if (interfaceOp.isValidInConstantExpr().succeeded())
return WalkResult::advance();
}
@@ -47,12 +47,12 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
}
} // namespace detail
-llvm::FailureOr<WasmSSALabelLevelInterface>
-WasmSSALabelBranchingInterface::getTargetOpFromBlock(::mlir::Block *block,
+llvm::FailureOr<WasmSSALabelLevelOpInterface>
+WasmSSALabelBranchingOpInterface::getTargetOpFromBlock(::mlir::Block *block,
uint32_t breakLevel) {
- WasmSSALabelLevelInterface res{};
+ WasmSSALabelLevelOpInterface res{};
for (size_t curLevel{0}; curLevel <= breakLevel; curLevel++) {
- res = dyn_cast_or_null<WasmSSALabelLevelInterface>(block->getParentOp());
+ res = dyn_cast_or_null<WasmSSALabelLevelOpInterface>(block->getParentOp());
if (!res)
return failure();
block = res->getBlock();
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
similarity index 98%
rename from mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
rename to mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index bb11b4efab78a..60a07899406f5 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -1,4 +1,4 @@
-//===- WebAssemblySSAOps.cpp - WasmSSA dialect operations ----------------===//
+//===- WasmSSAOps.cpp - WasmSSA dialect operations ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===---------------------------------------------------------------------===//
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAInterfaces.h"
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -23,7 +23,7 @@
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSAOps.cpp.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
@@ -87,7 +87,7 @@ Block *BlockOp::getLabelTarget() { return getTarget(); }
std::size_t BlockReturnOp::getExitLevel() { return 0; }
Block *BlockReturnOp::getTarget() {
- return cast<WasmSSALabelBranchingInterface>(getOperation())
+ return cast<WasmSSALabelBranchingOpInterface>(getOperation())
.getTargetOp()
.getOperation()
->getSuccessor(0);
diff --git a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
similarity index 84%
rename from mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
rename to mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
index f5f5d80c09ab9..db8780a2cb513 100644
--- a/mlir/lib/Dialect/WebAssemblySSA/IR/WebAssemblySSATypes.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
@@ -1,4 +1,4 @@
-//===- WebAssemblySSAOps.cpp - WasmSSA dialect operations ----------------===//
+//===- WasmSSAOps.cpp - WasmSSA dialect operations ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===---------------------------------------------------------------------===//
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSA.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/LogicalResult.h"
@@ -14,7 +14,7 @@
#include <optional>
namespace mlir::wasmssa {
-#include "mlir/Dialect/WebAssemblySSA/IR/WebAssemblySSATypeConstraints.cpp.inc"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSATypeConstraints.cpp.inc"
} // namespace mlir::wasmssa
using namespace mlir;
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/global-illegal.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/global-illegal.mlir
similarity index 100%
rename from mlir/test/Dialect/WebAssemblySSA/custom_parser/global-illegal.mlir
rename to mlir/test/Dialect/WasmSSA/custom_parser/global-illegal.mlir
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/global.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
similarity index 100%
rename from mlir/test/Dialect/WebAssemblySSA/custom_parser/global.mlir
rename to mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/import.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
similarity index 100%
rename from mlir/test/Dialect/WebAssemblySSA/custom_parser/import.mlir
rename to mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
diff --git a/mlir/test/Dialect/WebAssemblySSA/custom_parser/local.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
similarity index 100%
rename from mlir/test/Dialect/WebAssemblySSA/custom_parser/local.mlir
rename to mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
>From 881ec8e922f42100e6c489a58ebdd80a7404b6d8 Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Wed, 23 Jul 2025 13:26:07 +0900
Subject: [PATCH 06/11] Rename interfaces to drop the WasmSSA prefix since the
 namespace is added via cppNamespace
---
.../mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h | 2 +-
.../mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td | 16 ++++++++--------
.../mlir/Dialect/WasmSSA/IR/WasmSSAOps.td | 16 ++++++++--------
.../lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp | 16 ++++++++--------
mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp | 2 +-
5 files changed, 26 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
index b0986635605a5..d14009784e167 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
@@ -19,7 +19,7 @@
namespace mlir::wasmssa {
namespace detail {
LogicalResult verifyConstantExpressionInterface(Operation *op);
-LogicalResult verifyWasmSSALabelBranchingOpInterface(Operation *op);
+LogicalResult verifyLabelBranchingOpInterface(Operation *op);
} // namespace detail
template <class OperationType>
struct AlwaysValidConstantExprOpTrait
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
index 6e1239596f1d2..5be5fb6a69894 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
@@ -16,7 +16,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
-def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelOpInterface"> {
+def LabelLevelOpInterface : OpInterface<"LabelLevelOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that defines one level of nesting for wasm branching.
@@ -32,7 +32,7 @@ def WasmSSALabelLevelOpInterface : OpInterface<"WasmSSALabelLevelOpInterface"> {
];
}
-def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingOpInterface"> {
+def LabelBranchingOpInterface : OpInterface<"LabelBranchingOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Wasm operation that targets a label for a jump.
@@ -46,11 +46,11 @@ def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingOpInter
>,
InterfaceMethod<
/*desc=*/ "Returns the destination of this operation",
- /*returnType=*/ "WasmSSALabelLevelOpInterface",
+ /*returnType=*/ "LabelLevelOpInterface",
/*methodName=*/ "getTargetOp",
/*args=*/ (ins),
/*methodBody=*/ [{
- return *WasmSSALabelBranchingOpInterface::getTargetOpFromBlock($_op.getOperation()->getBlock(), $_op.getExitLevel());
+ return *LabelBranchingOpInterface::getTargetOpFromBlock($_op.getOperation()->getBlock(), $_op.getExitLevel());
}]
>,
InterfaceMethod<
@@ -60,18 +60,18 @@ def WasmSSALabelBranchingOpInterface : OpInterface<"WasmSSALabelBranchingOpInter
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- auto op = mlir::cast<WasmSSALabelBranchingOpInterface>(this->getOperation());
+ auto op = mlir::cast<LabelBranchingOpInterface>(this->getOperation());
return op.getTargetOp().getLabelTarget();
}]
>
];
let extraClassDeclaration = [{
- static ::llvm::FailureOr<WasmSSALabelLevelOpInterface> getTargetOpFromBlock(::mlir::Block *block, uint32_t level);
+ static ::llvm::FailureOr<LabelLevelOpInterface> getTargetOpFromBlock(::mlir::Block *block, uint32_t level);
}];
- let verify = [{return verifyWasmSSALabelBranchingOpInterface($_op);}];
+ let verify = [{return verifyLabelBranchingOpInterface($_op);}];
}
-def WasmSSAImportOpInterface : OpInterface<"WasmSSAImportOpInterface"> {
+def ImportOpInterface : OpInterface<"ImportOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that imports a symbol from an external wasm module;
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index 3c9d65ac57cb5..5d4b3131f1e16 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -27,7 +27,7 @@ class WasmSSA_Op<string mnemonic, list<Trait> traits = []> :
Op<WasmSSA_Dialect, mnemonic, traits>;
class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> :
- WasmSSA_Op<mnemonic, [Terminator, DeclareOpInterfaceMethods<WasmSSALabelLevelOpInterface>]> {
+ WasmSSA_Op<mnemonic, [Terminator, DeclareOpInterfaceMethods<LabelLevelOpInterface>]> {
let summary = summaryStr;
let arguments = (ins Variadic<WasmSSA_ValType>: $inputs);
let regions = (region AnyRegion: $body);
@@ -48,7 +48,7 @@ def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {}
def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
- DeclareOpInterfaceMethods<WasmSSALabelBranchingOpInterface>]> {
+ DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> {
let summary = "Return from the current block";
let arguments = (ins Variadic<WasmSSA_ValType>: $inputs);
let extraClassDeclaration = [{
@@ -59,7 +59,7 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
def WasmSSA_BranchIfOp : WasmSSA_Op<"branch_if", [
Terminator,
- DeclareOpInterfaceMethods<WasmSSALabelBranchingOpInterface>]> {
+ DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> {
let summary = "Jump to target level if condition has non-zero value";
let arguments = (ins I32: $condition,
UI32Attr: $exitLevel,
@@ -144,7 +144,7 @@ def WasmSSA_FuncCallOp : WasmSSA_Op<"call"> {
def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
Symbol,
CallableOpInterface,
- WasmSSAImportOpInterface]> {
+ ImportOpInterface]> {
let summary = "Importing a function variable";
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
@@ -199,7 +199,7 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
Symbol,
- WasmSSAImportOpInterface]> {
+ ImportOpInterface]> {
let summary = "Importing a global variable";
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
@@ -228,7 +228,7 @@ def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<Co
}
def WasmSSA_IfOp : WasmSSA_Op<"if", [Terminator,
- DeclareOpInterfaceMethods<WasmSSALabelLevelOpInterface>]> {
+ DeclareOpInterfaceMethods<LabelLevelOpInterface>]> {
let summary = "Execute the if region if condition value is nonzero, the else region otherwise.";
let arguments = (ins I32:$condition, Variadic<WasmSSA_ValType>: $inputs);
let regions = (region AnyRegion: $if, AnyRegion: $else);
@@ -302,7 +302,7 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
];
}
-def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, WasmSSAImportOpInterface]> {
+def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> {
let summary = "Importing a memory";
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
@@ -330,7 +330,7 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
"wasmssa::TableType":$type)>];
}
-def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, WasmSSAImportOpInterface]> {
+def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> {
let summary = "Importing a table";
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
index f7e556353fda2..e37fdbfe9e362 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
@@ -20,10 +20,10 @@ namespace mlir::wasmssa {
#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp.inc"
namespace detail {
-LogicalResult verifyWasmSSALabelBranchingOpInterface(Operation *op) {
- auto branchInterface = dyn_cast<WasmSSALabelBranchingOpInterface>(op);
- llvm::FailureOr<WasmSSALabelLevelOpInterface> res =
- WasmSSALabelBranchingOpInterface::getTargetOpFromBlock(
+LogicalResult verifyLabelBranchingOpInterface(Operation *op) {
+ auto branchInterface = dyn_cast<LabelBranchingOpInterface>(op);
+ llvm::FailureOr<LabelLevelOpInterface> res =
+ LabelBranchingOpInterface::getTargetOpFromBlock(
op->getBlock(), branchInterface.getExitLevel());
return success(succeeded(res));
}
@@ -47,12 +47,12 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
}
} // namespace detail
-llvm::FailureOr<WasmSSALabelLevelOpInterface>
-WasmSSALabelBranchingOpInterface::getTargetOpFromBlock(::mlir::Block *block,
+llvm::FailureOr<LabelLevelOpInterface>
+LabelBranchingOpInterface::getTargetOpFromBlock(::mlir::Block *block,
uint32_t breakLevel) {
- WasmSSALabelLevelOpInterface res{};
+ LabelLevelOpInterface res{};
for (size_t curLevel{0}; curLevel <= breakLevel; curLevel++) {
- res = dyn_cast_or_null<WasmSSALabelLevelOpInterface>(block->getParentOp());
+ res = dyn_cast_or_null<LabelLevelOpInterface>(block->getParentOp());
if (!res)
return failure();
block = res->getBlock();
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index 60a07899406f5..80b7ae0051e7e 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -87,7 +87,7 @@ Block *BlockOp::getLabelTarget() { return getTarget(); }
std::size_t BlockReturnOp::getExitLevel() { return 0; }
Block *BlockReturnOp::getTarget() {
- return cast<WasmSSALabelBranchingOpInterface>(getOperation())
+ return cast<LabelBranchingOpInterface>(getOperation())
.getTargetOp()
.getOperation()
->getSuccessor(0);
>From 995b710486724ee5bd9e9431b8b957fa511a6bf2 Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Thu, 24 Jul 2025 12:54:10 +0900
Subject: [PATCH 07/11] [mlir][wasm] LabelLevelInterface requires IsTerminator
trait + doc
---
.../mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h | 6 ++++++
.../mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td | 14 +++++++++++++-
2 files changed, 19 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
index d14009784e167..1aa85efef0e3e 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
@@ -20,6 +20,12 @@ namespace mlir::wasmssa {
namespace detail {
LogicalResult verifyConstantExpressionInterface(Operation *op);
LogicalResult verifyLabelBranchingOpInterface(Operation *op);
+template <typename OpType>
+LogicalResult verifyLabelLevelInterface() {
+ static_assert(OpType::template hasTrait<::mlir::OpTrait::IsTerminator>(),
+ "LabelLevelOp should be terminator ops");
+ return success();
+}
} // namespace detail
template <class OperationType>
struct AlwaysValidConstantExprOpTrait
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
index 5be5fb6a69894..f91c9a08633e2 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
@@ -20,7 +20,14 @@ def LabelLevelOpInterface : OpInterface<"LabelLevelOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that defines one level of nesting for wasm branching.
- These operation region can be targeted by branch instructions.
+
+ These ops defines Wasm control flow nesting levels (Wasm Labels) that Wasm
+ branching operations can target.
+ The branching operations specify a number of nesting level they want to exit,
+ and are redirected to the target of the corresponding nesting LabelLevelOp.
+
+ As multiple level can be escaped at once, the level defining ops need themselves
+ to be `Terminator` ops.
}];
let methods = [
InterfaceMethod<
@@ -30,6 +37,10 @@ def LabelLevelOpInterface : OpInterface<"LabelLevelOpInterface"> {
/*args=*/ (ins)
>
];
+
+ let verify = [{
+ return verifyLabelLevelInterface<ConcreteOp>();
+ }];
}
def LabelBranchingOpInterface : OpInterface<"LabelBranchingOpInterface"> {
@@ -65,6 +76,7 @@ def LabelBranchingOpInterface : OpInterface<"LabelBranchingOpInterface"> {
}]
>
];
+
let extraClassDeclaration = [{
static ::llvm::FailureOr<LabelLevelOpInterface> getTargetOpFromBlock(::mlir::Block *block, uint32_t level);
}];
>From 318b65be0f29760e18a07df82db2b7477bdeb0c7 Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Thu, 24 Jul 2025 13:13:16 +0900
Subject: [PATCH 08/11] [mlir][wasm] Remove direct use of odsState in custom
builder
---
mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp | 50 ++++++----------------
1 file changed, 12 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index 80b7ae0051e7e..edede740c2160 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -161,10 +161,7 @@ Block *FuncOp::addEntryBlock() {
void FuncOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, llvm::StringRef symbol,
FunctionType funcType) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("functionType", TypeAttr::get(funcType));
- odsState.addRegion();
+ FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested");
}
ParseResult FuncOp::parse(::mlir::OpAsmParser &parser,
@@ -235,11 +232,8 @@ void FuncImportOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, StringRef symbol,
StringRef moduleName, StringRef importName,
FunctionType type) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("moduleName", odsBuilder.getStringAttr(moduleName));
- odsState.addAttribute("importName", odsBuilder.getStringAttr(importName));
- odsState.addAttribute("type", TypeAttr::get(type));
+ FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+ type, {}, {}, odsBuilder.getStringAttr("nested"));
}
//===----------------------------------------------------------------------===//
@@ -249,12 +243,8 @@ void FuncImportOp::build(::mlir::OpBuilder &odsBuilder,
void GlobalOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, llvm::StringRef symbol,
Type type, bool isMutable) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("type", TypeAttr::get(type));
- if (isMutable)
- odsState.addAttribute("isMutable", odsBuilder.getUnitAttr());
- odsState.addRegion();
+ GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable,
+ odsBuilder.getStringAttr("nested"));
}
// Custom formats
@@ -326,13 +316,7 @@ void GlobalImportOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, StringRef symbol,
StringRef moduleName, StringRef importName,
Type type, bool isMutable) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("moduleName", odsBuilder.getStringAttr(moduleName));
- odsState.addAttribute("importName", odsBuilder.getStringAttr(importName));
- odsState.addAttribute("type", TypeAttr::get(type));
- if (isMutable)
- odsState.addAttribute("isMutable", odsBuilder.getUnitAttr());
+ GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, type, isMutable, odsBuilder.getStringAttr("nested"));
}
ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -449,9 +433,7 @@ Block *LoopOp::getLabelTarget() { return &getBody().front(); }
void MemOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, llvm::StringRef symbol,
LimitType limit) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("limits", TypeAttr::get(limit));
+ MemOp::build(odsBuilder, odsState, symbol, limit, odsBuilder.getStringAttr("nested"));
}
//===----------------------------------------------------------------------===//
@@ -462,11 +444,8 @@ void MemImportOp::build(mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
llvm::StringRef symbol, llvm::StringRef moduleName,
llvm::StringRef importName, LimitType limits) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("moduleName", odsBuilder.getStringAttr(moduleName));
- odsState.addAttribute("importName", odsBuilder.getStringAttr(importName));
- odsState.addAttribute("limits", TypeAttr::get(limits));
+ MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+ limits, odsBuilder.getStringAttr("nested"));
}
//===----------------------------------------------------------------------===//
@@ -498,9 +477,7 @@ void ReturnOp::build(::mlir::OpBuilder &odsBuilder,
void TableOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, llvm::StringRef symbol,
TableType type) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("type", TypeAttr::get(type));
+ TableOp::build(odsBuilder, odsState, symbol, type, odsBuilder.getStringAttr("nested"));
}
//===----------------------------------------------------------------------===//
@@ -511,9 +488,6 @@ void TableImportOp::build(mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
llvm::StringRef symbol, llvm::StringRef moduleName,
llvm::StringRef importName, TableType type) {
- odsState.addAttribute("sym_name", odsBuilder.getStringAttr(symbol));
- odsState.addAttribute("sym_visibility", odsBuilder.getStringAttr("nested"));
- odsState.addAttribute("moduleName", odsBuilder.getStringAttr(moduleName));
- odsState.addAttribute("importName", odsBuilder.getStringAttr(importName));
- odsState.addAttribute("type", TypeAttr::get(type));
+ TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+ type, odsBuilder.getStringAttr("nested"));
}
>From 89b64f4a3343fe03882aaaf1186e80aafc37100a Mon Sep 17 00:00:00 2001
From: Luc Forget <luc.forget at woven.toyota>
Date: Thu, 24 Jul 2025 13:29:47 +0900
Subject: [PATCH 09/11] [mlir][wasm] Verifier for LabelLevelInterface Target
validity
---
.../mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h | 4 ++--
.../mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td | 4 +++-
mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp | 11 +++++++++++
3 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
index 1aa85efef0e3e..a79fa1516a51b 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
@@ -21,17 +21,17 @@ namespace detail {
LogicalResult verifyConstantExpressionInterface(Operation *op);
LogicalResult verifyLabelBranchingOpInterface(Operation *op);
template <typename OpType>
-LogicalResult verifyLabelLevelInterface() {
+LogicalResult verifyLabelLevelInterfaceIsTerminator() {
static_assert(OpType::template hasTrait<::mlir::OpTrait::IsTerminator>(),
"LabelLevelOp should be terminator ops");
return success();
}
+LogicalResult verifyLabelLevelInterface(Operation *op);
} // namespace detail
template <class OperationType>
struct AlwaysValidConstantExprOpTrait
: public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprOpTrait> {};
-
template<typename OpType>
struct ConstantExpressionInitializerOpTrait : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerOpTrait>{
static LogicalResult verifyTrait(Operation* op) {
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
index f91c9a08633e2..1715aba7afd35 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
@@ -39,7 +39,9 @@ def LabelLevelOpInterface : OpInterface<"LabelLevelOpInterface"> {
];
let verify = [{
- return verifyLabelLevelInterface<ConcreteOp>();
+ return success(
+ succeeded(verifyLabelLevelInterfaceIsTerminator<ConcreteOp>()) &&
+ succeeded(verifyLabelLevelInterface($_op)));
}];
}
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
index e37fdbfe9e362..6dc079dc9d9cf 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir::wasmssa {
#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp.inc"
@@ -45,6 +46,16 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
});
return success(!resultState.wasInterrupted());
}
+
+LogicalResult verifyLabelLevelInterface(Operation *op) {
+ Block* target = cast<LabelLevelOpInterface>(op).getLabelTarget();
+ Region* targetRegion = target->getParent();
+ if (targetRegion != op->getParentRegion() ||
+ targetRegion->getParentOp() != op)
+ return op->emitError("target should be a block defined in same level than "
+ "operation or in its region.");
+ return success();
+}
} // namespace detail
llvm::FailureOr<LabelLevelOpInterface>
>From d0aed4d637297b0523ee65c1b106782194d3b10a Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Fri, 25 Jul 2025 15:55:14 +0900
Subject: [PATCH 10/11] Add tests for verifiers, improve reporting of verifier
errors and document interfaces
---
.../Dialect/WasmSSA/IR/WasmSSAInterfaces.h | 15 +++++
.../Dialect/WasmSSA/IR/WasmSSAInterfaces.td | 7 ++-
.../mlir/Dialect/WasmSSA/IR/WasmSSAOps.td | 4 +-
.../mlir/Dialect/WasmSSA/IR/WasmSSATypes.td | 2 +-
.../Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp | 2 +-
mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp | 63 +++++--------------
mlir/test/Dialect/WasmSSA/extend-illegal.mlir | 18 ++++++
.../{custom_parser => }/global-illegal.mlir | 11 ++++
mlir/test/Dialect/WasmSSA/locals-illegal.mlir | 17 +++++
.../Dialect/WasmSSA/reinterpret-illegal.mlir | 17 +++++
10 files changed, 103 insertions(+), 53 deletions(-)
create mode 100644 mlir/test/Dialect/WasmSSA/extend-illegal.mlir
rename mlir/test/Dialect/WasmSSA/{custom_parser => }/global-illegal.mlir (73%)
create mode 100644 mlir/test/Dialect/WasmSSA/locals-illegal.mlir
create mode 100644 mlir/test/Dialect/WasmSSA/reinterpret-illegal.mlir
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
index a79fa1516a51b..518229ff62729 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
@@ -18,20 +18,35 @@
namespace mlir::wasmssa {
namespace detail {
+/// Verify that `op` conforms to the ConstantExpressionInterface.
+/// `op` must be initialized with valid constant expressions.
LogicalResult verifyConstantExpressionInterface(Operation *op);
+
+/// Verify that `op` conforms to the LabelBranchingOpInterface.
+/// Checks that the branch targets a label within its scope.
LogicalResult verifyLabelBranchingOpInterface(Operation *op);
+
+/// Statically check that `OpType` (a LabelLevelOpInterface op) is a terminator.
template <typename OpType>
LogicalResult verifyLabelLevelInterfaceIsTerminator() {
static_assert(OpType::template hasTrait<::mlir::OpTrait::IsTerminator>(),
"LabelLevelOp should be terminator ops");
return success();
}
+
+/// Verify that `op` conforms to the LabelLevelOpInterface.
+/// `op`'s label target must be a block in `op`'s parent region or in its own region.
LogicalResult verifyLabelLevelInterface(Operation *op);
} // namespace detail
+
+/// Operations implementing this trait are considered as valid
+/// constant expressions in any context (In contrast of ConstantExprCheckOpInterface
+/// which are sometimes considered valid constant expressions.
template <class OperationType>
struct AlwaysValidConstantExprOpTrait
: public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprOpTrait> {};
+/// Trait used to verify operations that need a constant expression initializer.
template<typename OpType>
struct ConstantExpressionInitializerOpTrait : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerOpTrait>{
static LogicalResult verifyTrait(Operation* op) {
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
index 1715aba7afd35..dc24452ba3c59 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
@@ -135,23 +135,24 @@ def ConstantExprCheckOpInterface :
OpInterface<"ConstantExprCheckOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
- Base interface for operations that can be used in a Wasm Constant Expression.
+ Interface for allowing to verify that operations can be used in a Wasm Constant Expression.
}];
let methods = [
InterfaceMethod<
/*desc=*/ [{
Returns success if the current operation is valid in a constant expression context.
+ A diagnostic is emitted on error.
}],
/*returnType=*/ "::mlir::LogicalResult",
- /*methodName=*/ "isValidInConstantExpr",
+ /*methodName=*/ "CheckValidInConstantExpr",
/*args=*/ (ins)
>
];
}
def AlwaysValidInConstantExprOpTrait : NativeOpTrait<"AlwaysValidConstantExprOpTrait", [], [{
- ::mlir::LogicalResult isValidInConstantExpr() {
+ ::mlir::LogicalResult CheckValidInConstantExpr() {
return success();
}
}]> {
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index 5d4b3131f1e16..8086a4828675b 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -218,6 +218,7 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
"bool": $isMutable)>
];
let hasCustomAssemblyFormat = 1;
+ //let assemblyFormat = "$importName `from` $moduleName `as` $sym_name oilist (`mutable` $isMutable | `vis` $sym_visibility) `:` $type attr-dict";
}
def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<ConstantExprCheckOpInterface>]> {
@@ -269,7 +270,6 @@ def WasmSSA_LocalGetOp : WasmSSA_Op<"local_get", [
let arguments = (ins WasmSSA_LocalRef: $localVar);
let results = (outs WasmSSA_ValType: $result);
let assemblyFormat = "$localVar `:` type($localVar) attr-dict";
- let hasVerifier = 1;
}
def WasmSSA_LocalSetOp : WasmSSA_Op<"local_set"> {
@@ -598,7 +598,7 @@ def WasmSSA_ExtendLowBitsSOp : WasmSSA_Op<"extend", [AllTypesMatch<["input", "re
let arguments = (ins WasmSSA_IntegerType:$input, Builtin_IntegerAttr: $bitsToTake);
let results = (outs WasmSSA_IntegerType: $result);
let hasVerifier = 1;
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "$bitsToTake `low` `bits` `from` $input `:` type($input) attr-dict";
}
def WasmSSA_PromoteOp : WasmSSA_ConversionOp<"promote",
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
index 946362e609fed..f1f273a7c55a3 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
@@ -51,7 +51,7 @@ def WasmSSA_FuncType : TypeAlias<FunctionType>;
def WasmSSA_LimitType : WasmSSA_Type<"Limit", "limit"> {
let summary = "Wasm limit type";
- let parameters = (ins "uint32_t":$min,
+ let parameters = (ins "uint32_t":$min,
"std::optional<uint32_t>":$max);
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
index 6dc079dc9d9cf..88fd5fb540c15 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
@@ -37,7 +37,7 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
return WalkResult::advance();
if (auto interfaceOp =
dyn_cast<ConstantExprCheckOpInterface>(currentOp)) {
- if (interfaceOp.isValidInConstantExpr().succeeded())
+ if (interfaceOp.CheckValidInConstantExpr().succeeded())
return WalkResult::advance();
}
op->emitError("expected a constant initializer for this operator, got ")
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index edede740c2160..fd5a4e9ae5294 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -97,39 +97,6 @@ Block *BlockReturnOp::getTarget() {
// ExtendLowBitsSOp
//===----------------------------------------------------------------------===//
-ParseResult ExtendLowBitsSOp::parse(::mlir::OpAsmParser &parser,
- ::mlir::OperationState &result) {
- OpAsmParser::UnresolvedOperand operand;
- uint64_t nBits;
- ParseResult parseRes = parser.parseInteger(nBits);
- parseRes = parser.parseKeyword("low");
- parseRes = parser.parseKeyword("bits");
- parseRes = parser.parseKeyword("from");
- parseRes = parser.parseOperand(operand);
- parseRes = parser.parseColon();
- Type inType;
- parseRes = parser.parseType(inType);
- if (!inType.isInteger())
- return failure();
- llvm::SmallVector<Value, 1> opVal;
- parseRes = parser.resolveOperand(operand, inType, opVal);
- if (parseRes.failed())
- return failure();
- result.addOperands(opVal);
- result.addAttribute(
- ExtendLowBitsSOp::getBitsToTakeAttrName(OperationName{
- ExtendLowBitsSOp::getOperationName(), parser.getContext()}),
- parser.getBuilder().getI64IntegerAttr(nBits));
- result.addTypes(inType);
- return success();
-}
-
-void ExtendLowBitsSOp::print(OpAsmPrinter &p) {
- p << " " << getBitsToTake().getUInt() << " low bits from ";
- p.printOperand(getInput());
- p << ": " << getInput().getType();
-}
-
LogicalResult ExtendLowBitsSOp::verify() {
auto bitsToTake = getBitsToTake().getValue().getLimitedValue();
if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8)
@@ -139,7 +106,7 @@ LogicalResult ExtendLowBitsSOp::verify() {
if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth())
return emitError("trying to extend the ")
<< bitsToTake << " low bits from a " << getInput().getType()
- << " value";
+ << " value is illegal";
return success();
}
@@ -292,14 +259,20 @@ void GlobalOp::print(OpAsmPrinter &printer) {
//===----------------------------------------------------------------------===//
// Custom interface overrides
-LogicalResult GlobalGetOp::isValidInConstantExpr() {
+LogicalResult GlobalGetOp::CheckValidInConstantExpr() {
StringRef referencedSymbol = getGlobal();
Operation *symTableOp =
getOperation()->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!symTableOp)
+ return emitError(
+ "cannot find the symbol table associated with this operation");
+ // NOTE: Having to lookup the symbol inside the symbol table anytime the verifier
+ // is called can be costly. This may be improved with caching or another architecture
+ // for the constant checking mechanism.
Operation *definitionOp =
SymbolTable::lookupSymbolIn(symTableOp, referencedSymbol);
if (!definitionOp)
- return failure();
+ return emitError() << "symbol @" << referencedSymbol << " is undefined";
auto definitionImport = llvm::dyn_cast<GlobalImportOp>(definitionOp);
if (!definitionImport || definitionImport.getIsMutable()) {
return emitError("global.get op is considered constant if it's referring "
@@ -389,18 +362,14 @@ LogicalResult LocalGetOp::inferReturnTypes(
return inferTeeGetResType(operands, inferredReturnTypes);
}
-LogicalResult LocalGetOp::verify() {
- return success(getLocalVar().getType().getElementType() ==
- getResult().getType());
-}
-
//===----------------------------------------------------------------------===//
// LocalSetOp
//===----------------------------------------------------------------------===//
LogicalResult LocalSetOp::verify() {
- return success(getLocalVar().getType().getElementType() ==
- getValue().getType());
+ if (getLocalVar().getType().getElementType() != getValue().getType())
+ return emitError("input type and result type of local.set do not match");
+ return llvm::success();
}
//===----------------------------------------------------------------------===//
@@ -415,9 +384,11 @@ LogicalResult LocalTeeOp::inferReturnTypes(
}
LogicalResult LocalTeeOp::verify() {
- return success(getLocalVar().getType().getElementType() ==
- getValue().getType() &&
- getValue().getType() == getResult().getType());
+ if (getLocalVar().getType().getElementType() !=
+ getValue().getType() ||
+ getValue().getType() != getResult().getType())
+ return emitError("input type and output type of local.tee do not match");
+ return llvm::success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/WasmSSA/extend-illegal.mlir b/mlir/test/Dialect/WasmSSA/extend-illegal.mlir
new file mode 100644
index 0000000000000..8d782801c33f9
--- /dev/null
+++ b/mlir/test/Dialect/WasmSSA/extend-illegal.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+
+wasmssa.func nested @extend_low_64() -> i32 {
+ %0 = wasmssa.const 10 : i32
+ // expected-error at +1 {{extend op can only take 8, 16 or 32 bits. Got 64}}
+ %1 = wasmssa.extend 64 low bits from %0: i32
+ wasmssa.return %1 : i32
+}
+
+// -----
+
+wasmssa.func nested @extend_too_much() -> i32 {
+ %0 = wasmssa.const 10 : i32
+ // expected-error at +1 {{trying to extend the 32 low bits from a 'i32' value is illegal}}
+ %1 = wasmssa.extend 32 low bits from %0: i32
+ wasmssa.return %1 : i32
+}
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/global-illegal.mlir b/mlir/test/Dialect/WasmSSA/global-illegal.mlir
similarity index 73%
rename from mlir/test/Dialect/WasmSSA/custom_parser/global-illegal.mlir
rename to mlir/test/Dialect/WasmSSA/global-illegal.mlir
index 3571565564b6d..f376a2f410488 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/global-illegal.mlir
+++ b/mlir/test/Dialect/WasmSSA/global-illegal.mlir
@@ -21,3 +21,14 @@ module {
wasmssa.return %0 : i32
}
}
+
+// -----
+
+module {
+ // expected-error at +1 {{expected a constant initializer for this operator}}
+ wasmssa.global @global_1 i32 : {
+ // expected-error at +1 {{symbol @glarble is undefined}}
+ %0 = wasmssa.global_get @glarble : i32
+ wasmssa.return %0 : i32
+ }
+}
diff --git a/mlir/test/Dialect/WasmSSA/locals-illegal.mlir b/mlir/test/Dialect/WasmSSA/locals-illegal.mlir
new file mode 100644
index 0000000000000..35c590b36b289
--- /dev/null
+++ b/mlir/test/Dialect/WasmSSA/locals-illegal.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+wasmssa.func nested @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 {
+ %0 = wasmssa.const 3 : i64
+ // expected-error at +1 {{input type and result type of local.set do not match}}
+ wasmssa.local_set %arg0 : ref to i32 to %0 : i64
+ wasmssa.return %0 : i64
+}
+
+// -----
+
+wasmssa.func nested @local_tee_err(%arg0: !wasmssa<local ref to i32>) -> i32 {
+ %0 = wasmssa.const 3 : i64
+ // expected-error at +1 {{input type and output type of local.tee do not match}}
+ %1 = wasmssa.local_tee %arg0 : ref to i32 to %0 : i64
+ wasmssa.return %1 : i32
+}
diff --git a/mlir/test/Dialect/WasmSSA/reinterpret-illegal.mlir b/mlir/test/Dialect/WasmSSA/reinterpret-illegal.mlir
new file mode 100644
index 0000000000000..3dd8509e56b35
--- /dev/null
+++ b/mlir/test/Dialect/WasmSSA/reinterpret-illegal.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+wasmssa.func @f32_reinterpret_f32() -> f32 {
+ %0 = wasmssa.const -1.000000e+00 : f32
+ // expected-error at +1 {{reinterpret input and output type should be distinct.}}
+ %1 = wasmssa.reinterpret %0 : f32 as f32
+ wasmssa.return %1 : f32
+}
+
+// -----
+
+wasmssa.func @f64_reinterpret_f32() -> f32 {
+ %0 = wasmssa.const -1.000000e+00 : f64
+ // expected-error at +1 {{input type ('f64') and output type ('f32') have incompatible bit widths.}}
+ %1 = wasmssa.reinterpret %0 : f64 as f32
+ wasmssa.return %1 : f32
+}
>From 11555aabc5fa9869d07ced8ecb2c1c1ed4d5c72e Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire <ferdinand.lemaire at woven-planet.global>
Date: Mon, 28 Jul 2025 10:59:12 +0900
Subject: [PATCH 11/11] [mlir][wasm] Remove constant expression interface and
make use of SymbolUserOpInterface to check for correct use of a constant symbol
---
.../Dialect/WasmSSA/IR/WasmSSAInterfaces.h | 20 +++++++-------
.../Dialect/WasmSSA/IR/WasmSSAInterfaces.td | 26 +------------------
.../mlir/Dialect/WasmSSA/IR/WasmSSAOps.td | 7 +++--
.../mlir/Dialect/WasmSSA/IR/WasmSSATypes.td | 7 ++---
.../Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp | 14 ++++------
mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp | 21 +++++++--------
mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp | 26 -------------------
mlir/test/Dialect/WasmSSA/global-illegal.mlir | 2 --
8 files changed, 33 insertions(+), 90 deletions(-)
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
index 518229ff62729..d8948317477d0 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
@@ -40,18 +40,20 @@ LogicalResult verifyLabelLevelInterface(Operation *op);
} // namespace detail
/// Operations implementing this trait are considered as valid
-/// constant expressions in any context (In contrast of ConstantExprCheckOpInterface
-/// which are sometimes considered valid constant expressions.
+/// constant expressions. Operations whose validity also depends on their
+/// context (e.g. `global_get`, which must refer to an immutable imported
+/// global) check these extra requirements in their own verifiers.
template <class OperationType>
-struct AlwaysValidConstantExprOpTrait
- : public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprOpTrait> {};
+struct ConstantExprOpTrait
+ : public OpTrait::TraitBase<OperationType, ConstantExprOpTrait> {};
/// Trait used to verify operations that need a constant expression initializer.
-template<typename OpType>
-struct ConstantExpressionInitializerOpTrait : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerOpTrait>{
- static LogicalResult verifyTrait(Operation* op) {
- return detail::verifyConstantExpressionInterface(op);
- }
+template <typename OpType>
+struct ConstantExpressionInitializerOpTrait
+ : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerOpTrait> {
+ static LogicalResult verifyTrait(Operation *op) {
+ return detail::verifyConstantExpressionInterface(op);
+ }
};
} // namespace mlir::wasmssa
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
index dc24452ba3c59..b586a08fc0f4d 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
@@ -131,31 +131,7 @@ def ConstantExpressionInitializerOpTrait : NativeOpTrait<"ConstantExpressionInit
let cppNamespace = "::mlir::wasmssa";
}
-def ConstantExprCheckOpInterface :
- OpInterface<"ConstantExprCheckOpInterface"> {
- let cppNamespace = "::mlir::wasmssa";
- let description = [{
- Interface for allowing to verify that operations can be used in a Wasm Constant Expression.
- }];
-
- let methods = [
- InterfaceMethod<
- /*desc=*/ [{
- Returns success if the current operation is valid in a constant expression context.
- A diagnostic is emitted on error.
- }],
- /*returnType=*/ "::mlir::LogicalResult",
- /*methodName=*/ "CheckValidInConstantExpr",
- /*args=*/ (ins)
- >
- ];
-}
-
-def AlwaysValidInConstantExprOpTrait : NativeOpTrait<"AlwaysValidConstantExprOpTrait", [], [{
- ::mlir::LogicalResult CheckValidInConstantExpr() {
- return success();
- }
- }]> {
+def ConstantExprOpTrait : NativeOpTrait<"ConstantExprOpTrait"> {
let cppNamespace = "::mlir::wasmssa";
}
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index 8086a4828675b..5b2bf1b886f21 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -70,8 +70,7 @@ def WasmSSA_BranchIfOp : WasmSSA_Op<"branch_if", [
def WasmSSA_ConstOp : WasmSSA_Op<"const", [
AllTypesMatch<["value", "result"]>,
- ConstantExprCheckOpInterface,
- AlwaysValidInConstantExprOpTrait]> {
+ ConstantExprOpTrait]> {
let summary = "Operator that represents a constant value";
let arguments = (ins TypedAttrInterface: $value);
let results = (outs WasmSSA_NumericType: $result);
@@ -218,10 +217,10 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
"bool": $isMutable)>
];
let hasCustomAssemblyFormat = 1;
- //let assemblyFormat = "$importName `from` $moduleName `as` $sym_name oilist (`mutable` $isMutable | `vis` $sym_visibility) `:` $type attr-dict";
}
-def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<ConstantExprCheckOpInterface>]> {
+def WasmSSA_GlobalGetOp : WasmSSA_Op<"global_get", [DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ ConstantExprOpTrait]> {
let summary = "Returns the value of the global passed as argument.";
let arguments = (ins FlatSymbolRefAttr: $global);
let results = (outs WasmSSA_ValType: $global_val);
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
index f1f273a7c55a3..4d78836cce75e 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSATypes.td
@@ -52,17 +52,18 @@ def WasmSSA_FuncType : TypeAlias<FunctionType>;
def WasmSSA_LimitType : WasmSSA_Type<"Limit", "limit"> {
let summary = "Wasm limit type";
let parameters = (ins "uint32_t":$min,
- "std::optional<uint32_t>":$max);
- let hasCustomAssemblyFormat = 1;
+ OptionalParameter<"std::optional<uint32_t>">:$max);
+ let assemblyFormat = "`[` $min `` `:` ($max^)? `]`";
}
def WasmSSA_LocalRef : WasmSSA_Type<"LocalRef", "local"> {
let summary = "Type of a local variable";
let parameters = (ins WasmSSA_ValType: $elementType);
let assemblyFormat = "`ref` `to` $elementType";
- let builders = [TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
+ let builders = [TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
return get(typeParam.getContext(), typeParam);
}]>,];
+
}
def WasmSSA_TableType : WasmSSA_Type<"Table", "tabletype"> {
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
index 88fd5fb540c15..a33b81f5d3285 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
@@ -33,13 +33,9 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
Region &initializerRegion = op->getRegion(0);
WalkResult resultState =
initializerRegion.walk([&](Operation *currentOp) -> WalkResult {
- if (isa<ReturnOp>(currentOp))
+ if (isa<ReturnOp>(currentOp) ||
+ currentOp->hasTrait<ConstantExprOpTrait>())
return WalkResult::advance();
- if (auto interfaceOp =
- dyn_cast<ConstantExprCheckOpInterface>(currentOp)) {
- if (interfaceOp.CheckValidInConstantExpr().succeeded())
- return WalkResult::advance();
- }
op->emitError("expected a constant initializer for this operator, got ")
<< currentOp;
return WalkResult::interrupt();
@@ -48,8 +44,8 @@ LogicalResult verifyConstantExpressionInterface(Operation *op) {
}
LogicalResult verifyLabelLevelInterface(Operation *op) {
-  Block* target = cast<LabelLevelOpInterface>(op).getLabelTarget();
-  Region* targetRegion = target->getParent();
-  if (targetRegion != op->getParentRegion() ||
+  Block *target = cast<LabelLevelOpInterface>(op).getLabelTarget();
+  Region *targetRegion = target->getParent();
+  if (targetRegion != op->getParentRegion() &&
      targetRegion->getParentOp() != op)
    return op->emitError("target should be a block defined in same level than "
@@ -60,7 +56,7 @@ LogicalResult verifyLabelLevelInterface(Operation *op) {
llvm::FailureOr<LabelLevelOpInterface>
LabelBranchingOpInterface::getTargetOpFromBlock(::mlir::Block *block,
- uint32_t breakLevel) {
+ uint32_t breakLevel) {
LabelLevelOpInterface res{};
for (size_t curLevel{0}; curLevel <= breakLevel; curLevel++) {
res = dyn_cast_or_null<LabelLevelOpInterface>(block->getParentOp());
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index fd5a4e9ae5294..41874026b0450 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -258,19 +258,16 @@ void GlobalOp::print(OpAsmPrinter &printer) {
// GlobalGetOp
//===----------------------------------------------------------------------===//
-// Custom interface overrides
-LogicalResult GlobalGetOp::CheckValidInConstantExpr() {
+LogicalResult
+GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // If the parent requires a constant context, verify that global.get is a
+  // constant as defined by the Wasm specification.
+ if(!this->getOperation()->getParentWithTrait<ConstantExpressionInitializerOpTrait>())
+ return success();
+ Operation *symTabOp = SymbolTable::getNearestSymbolTable(*this);
StringRef referencedSymbol = getGlobal();
- Operation *symTableOp =
- getOperation()->getParentWithTrait<OpTrait::SymbolTable>();
- if (!symTableOp)
- return emitError(
- "cannot find the symbol table associated with this operation");
- // NOTE: Having to lookup the symbol inside the symbol table anytime the verifier
- // is called can be costly. This may be improved with caching or another architecture
- // for the constant checking mechanism.
- Operation *definitionOp =
- SymbolTable::lookupSymbolIn(symTableOp, referencedSymbol);
+ Operation *definitionOp = symbolTable.lookupSymbolIn(
+ symTabOp, StringAttr::get(this->getContext(), referencedSymbol));
if (!definitionOp)
return emitError() << "symbol @" << referencedSymbol << " is undefined";
auto definitionImport = llvm::dyn_cast<GlobalImportOp>(definitionOp);
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
index db8780a2cb513..bee8c8167248d 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
@@ -16,29 +16,3 @@
namespace mlir::wasmssa {
#include "mlir/Dialect/WasmSSA/IR/WasmSSATypeConstraints.cpp.inc"
} // namespace mlir::wasmssa
-
-using namespace mlir;
-using namespace mlir::wasmssa;
-
-Type LimitType::parse(::mlir::AsmParser &parser) {
- auto res = parser.parseLSquare();
- uint32_t minLimit{0};
- std::optional<uint32_t> maxLimit{std::nullopt};
- res = parser.parseInteger(minLimit);
- res = parser.parseColon();
- uint32_t maxValue{0};
- auto maxParseRes = parser.parseOptionalInteger(maxValue);
- if (maxParseRes.has_value() && (*maxParseRes).succeeded())
- maxLimit = maxValue;
-
- res = parser.parseRSquare();
- return LimitType::get(parser.getContext(), minLimit, maxLimit);
-}
-
-void LimitType::print(AsmPrinter &printer) const {
- printer << '[' << getMin() << ':';
- std::optional<uint32_t> maxLim = getMax();
- if (maxLim)
- printer << *maxLim;
- printer << ']';
-}
diff --git a/mlir/test/Dialect/WasmSSA/global-illegal.mlir b/mlir/test/Dialect/WasmSSA/global-illegal.mlir
index f376a2f410488..b9cafd8b900bf 100644
--- a/mlir/test/Dialect/WasmSSA/global-illegal.mlir
+++ b/mlir/test/Dialect/WasmSSA/global-illegal.mlir
@@ -14,7 +14,6 @@ module {
module {
wasmssa.import_global "glob" from "my_module" as @global_0 mutable nested : i32
- // expected-error at +1 {{expected a constant initializer for this operator}}
wasmssa.global @global_1 i32 : {
// expected-error at +1 {{global.get op is considered constant if it's referring to a import.global symbol marked non-mutable}}
%0 = wasmssa.global_get @global_0 : i32
@@ -25,7 +24,6 @@ module {
// -----
module {
- // expected-error at +1 {{expected a constant initializer for this operator}}
wasmssa.global @global_1 i32 : {
// expected-error at +1 {{symbol @glarble is undefined}}
%0 = wasmssa.global_get @glarble : i32
More information about the Mlir-commits
mailing list