[Mlir-commits] [mlir] e5114a2 - [MLIR][Python] Add python bindings for IRDL dialect (#158488)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 18 19:10:44 PDT 2025
Author: Twice
Date: 2025-09-19T10:10:39+08:00
New Revision: e5114a2016557fc0dd6014b838e91ca025e23b29
URL: https://github.com/llvm/llvm-project/commit/e5114a2016557fc0dd6014b838e91ca025e23b29
DIFF: https://github.com/llvm/llvm-project/commit/e5114a2016557fc0dd6014b838e91ca025e23b29.diff
LOG: [MLIR][Python] Add python bindings for IRDL dialect (#158488)
This PR adds basic Python bindings for the IRDL dialect, so that Python
users can create and load IRDL dialects from Python. This allows users, to
some extent, to define dialects in Python without having to modify
MLIR’s CMake/TableGen/C++ code and rebuild, making prototyping more
convenient.
A basic example is shown below (and also in the added test case):
```python
# NOTE: like the added test, this snippet is expected to run inside an
# active MLIR `Context` and `Location`.

# create a module with IRDL dialects
module = Module.create()
with InsertionPoint(module.body):
    # `irdl.dialect @irdl_test`: the dialect being defined
    dialect = irdl.DialectOp("irdl_test")
    with InsertionPoint(dialect.body):
        # `irdl.operation @test_op`: one op inside that dialect
        op = irdl.OperationOp("test_op")
        with InsertionPoint(op.body):
            # constrain the single operand `input` to be exactly f32
            f32 = irdl.is_(TypeAttr.get(F32Type.get()))
            irdl.operands_([f32], ["input"], [irdl.Variadicity.single])

# load the module
irdl.load_dialects(module)

# use the op defined in IRDL
m = Module.parse("""
  module {
    %a = arith.constant 1.0 : f32
    "irdl_test.test_op"(%a) : (f32) -> ()
  }
""")
```
Added:
mlir/lib/Bindings/Python/DialectIRDL.cpp
mlir/python/mlir/dialects/IRDLOps.td
mlir/python/mlir/dialects/irdl.py
mlir/test/python/dialects/irdl.py
Modified:
mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
mlir/python/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
index 1bfa7f5cb894b..2f568e8b6c42a 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
@@ -13,7 +13,7 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 4a83eb62fba32..3b6b09973645c 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -13,10 +13,10 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
#define MLIR_DIALECT_IRDL_IR_IRDLOPS
-include "IRDL.td"
-include "IRDLAttributes.td"
-include "IRDLTypes.td"
-include "IRDLInterfaces.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
+include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
+include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
index 9b17bf23df182..9cde433cf33a6 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
@@ -14,7 +14,7 @@
#define MLIR_DIALECT_IRDL_IR_IRDLTYPES
include "mlir/IR/AttrTypeBase.td"
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<IRDL_Dialect, name, traits> {
diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp
new file mode 100644
index 0000000000000..08bcab97c03ec
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp
@@ -0,0 +1,35 @@
+//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/IRDL.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectIRDLSubmodule(nb::module_ &m) {
+  m.def(
+      "load_dialects",
+      [](MlirModule module) {
+        if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
+          throw std::runtime_error(
+              "failed to load IRDL dialects from the input module");
+      },
+      nb::arg("module"), "Load IRDL dialects from the given module.");
+}
+
+NB_MODULE(_mlirDialectsIRDL, m) {
+  m.doc() = "MLIR IRDL dialect.";
+
+  populateDialectIRDLSubmodule(m);
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c983914722ce1..7b2e1b8c36f25 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/IRDLOps.td
+  SOURCES dialects/irdl.py
+  DIALECT_NAME irdl
+  GEN_ENUM_BINDINGS
+)
+
################################################################################
# Python extensions.
# The sources for these are all in lib/Bindings/Python, but since they have to
@@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MLIRCAPITransformDialect
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
+  MODULE_NAME _mlirDialectsIRDL
+  ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
+  SOURCES
+    DialectIRDL.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPIIRDL
+)
+
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async
diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td
new file mode 100644
index 0000000000000..7b061fcf30836
--- /dev/null
+++ b/mlir/python/mlir/dialects/IRDLOps.td
@@ -0,0 +1,14 @@
+//===-- IRDLOps.td - Entry point for IRDL binding ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_IRDL_OPS
+#define PYTHON_BINDINGS_IRDL_OPS
+
+include "mlir/Dialect/IRDL/IR/IRDLOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
new file mode 100644
index 0000000000000..1ec951b69b646
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -0,0 +1,92 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._irdl_ops_gen import *
+from ._irdl_ops_gen import _Dialect
+from ._irdl_enum_gen import *
+from .._mlir_libs._mlirDialectsIRDL import *
+from ..ir import register_attribute_builder
+from ._ods_common import _cext as _ods_cext
+from typing import Union, Sequence
+
+_ods_ir = _ods_cext.ir
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DialectOp(DialectOp):
+    __doc__ = DialectOp.__doc__
+
+    def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self) -> _ods_ir.Block:
+        return self.regions[0].blocks[0]
+
+
+def dialect(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> DialectOp:
+    return DialectOp(sym_name=sym_name, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
+    __doc__ = OperationOp.__doc__
+
+    def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self) -> _ods_ir.Block:
+        return self.regions[0].blocks[0]
+
+
+def operation_(
+    sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None
+) -> OperationOp:
+    return OperationOp(sym_name=sym_name, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TypeOp(TypeOp):
+    __doc__ = TypeOp.__doc__
+
+    def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self) -> _ods_ir.Block:
+        return self.regions[0].blocks[0]
+
+
+def type_(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> TypeOp:
+    return TypeOp(sym_name=sym_name, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AttributeOp(AttributeOp):
+    __doc__ = AttributeOp.__doc__
+
+    def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self) -> _ods_ir.Block:
+        return self.regions[0].blocks[0]
+
+
+def attribute(
+    sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None
+) -> AttributeOp:
+    return AttributeOp(sym_name=sym_name, loc=loc, ip=ip)
+
+
+@register_attribute_builder("VariadicityArrayAttr")
+def _variadicity_array_attr(x: Sequence[Variadicity], context) -> _ods_ir.Attribute:
+    return _ods_ir.Attribute.parse(
+        f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>", context
+    )
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
new file mode 100644
index 0000000000000..ed62db9b69968
--- /dev/null
+++ b/mlir/test/python/dialects/irdl.py
@@ -0,0 +1,66 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.irdl import *
+import sys
+
+
+def run(f):
+    print("\nTEST:", f.__name__, file=sys.stderr)
+    f()
+
+
+# CHECK: TEST: testIRDL
+@run
+def testIRDL():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            irdl_test = dialect("irdl_test")
+            with InsertionPoint(irdl_test.body):
+                op = operation_("test_op")
+                with InsertionPoint(op.body):
+                    f32 = is_(TypeAttr.get(F32Type.get()))
+                    operands_([f32], ["input"], [Variadicity.single])
+                type1 = type_("type1")
+                with InsertionPoint(type1.body):
+                    f32 = is_(TypeAttr.get(F32Type.get()))
+                    parameters([f32], ["val"])
+                attr1 = attribute("attr1")
+                with InsertionPoint(attr1.body):
+                    test = is_(StringAttr.get("test"))
+                    parameters([test], ["val"])
+
+        # CHECK: module {
+        # CHECK: irdl.dialect @irdl_test {
+        # CHECK: irdl.operation @test_op {
+        # CHECK: %0 = irdl.is f32
+        # CHECK: irdl.operands(input: %0)
+        # CHECK: }
+        # CHECK: irdl.type @type1 {
+        # CHECK: %0 = irdl.is f32
+        # CHECK: irdl.parameters(val: %0)
+        # CHECK: }
+        # CHECK: irdl.attribute @attr1 {
+        # CHECK: %0 = irdl.is "test"
+        # CHECK: irdl.parameters(val: %0)
+        # CHECK: }
+        # CHECK: }
+        # CHECK: }
+        module.operation.verify()
+        module.dump()
+
+        load_dialects(module)
+
+        m = Module.parse(
+            """
+            module {
+                %a = arith.constant 1.0 : f32
+                "irdl_test.test_op"(%a) : (f32) -> ()
+            }
+            """
+        )
+        # CHECK: module {
+        # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+        # CHECK: }
+        m.dump()
More information about the Mlir-commits
mailing list