[Mlir-commits] [mlir] [MLIR][Python] Add python bindings for IRDL dialect (PR #158488)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 16 19:55:38 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/158488
>From 8cbbb6b7a3b21775498cf12bf64ebd7295d4cdf8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:16:38 +0800
Subject: [PATCH 1/5] [MLIR][Python] Add python bindings for IRDL dialect
---
.../mlir/Dialect/IRDL/IR/IRDLAttributes.td | 2 +-
mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td | 8 ++--
.../include/mlir/Dialect/IRDL/IR/IRDLTypes.td | 2 +-
mlir/lib/Bindings/Python/DialectIRDL.cpp | 36 +++++++++++++++
mlir/python/CMakeLists.txt | 23 ++++++++++
mlir/python/mlir/dialects/IRDLOps.td | 14 ++++++
mlir/python/mlir/dialects/irdl.py | 43 ++++++++++++++++++
mlir/test/python/dialects/irdl.py | 45 +++++++++++++++++++
8 files changed, 167 insertions(+), 6 deletions(-)
create mode 100644 mlir/lib/Bindings/Python/DialectIRDL.cpp
create mode 100644 mlir/python/mlir/dialects/IRDLOps.td
create mode 100644 mlir/python/mlir/dialects/irdl.py
create mode 100644 mlir/test/python/dialects/irdl.py
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
index 1bfa7f5cb894b..2f568e8b6c42a 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
@@ -13,7 +13,7 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 4a83eb62fba32..3b6b09973645c 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -13,10 +13,10 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
#define MLIR_DIALECT_IRDL_IR_IRDLOPS
-include "IRDL.td"
-include "IRDLAttributes.td"
-include "IRDLTypes.td"
-include "IRDLInterfaces.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
+include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
+include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
index 9b17bf23df182..9cde433cf33a6 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
@@ -14,7 +14,7 @@
#define MLIR_DIALECT_IRDL_IR_IRDLTYPES
include "mlir/IR/AttrTypeBase.td"
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<IRDL_Dialect, name, traits> {
diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp
new file mode 100644
index 0000000000000..8264d21d4fa03
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp
@@ -0,0 +1,36 @@
+//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/IRDL.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectIRDLSubmodule(nb::module_ &m) { // Adds the IRDL helpers to `m`.
+ m.def(
+ "load_dialects",
+ [](MlirModule module) { // MlirModule caster comes from NanobindAdaptors.h.
+ if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module))) // Report C-API failure.
+ throw std::runtime_error( // nanobind surfaces this to Python as RuntimeError.
+ "failed to load IRDL dialects from the input module");
+ },
+ nb::arg("module"), "Load IRDL dialects from the given module.");
+}
+
+NB_MODULE(_mlirDialectsIRDL, m) { // Must match MODULE_NAME in mlir/python/CMakeLists.txt.
+ m.doc() = "MLIR IRDL dialect.";
+
+ populateDialectIRDLSubmodule(m); // Exposes load_dialects(module).
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c983914722ce1..7b2e1b8c36f25 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")
+declare_mlir_dialect_python_bindings(  # ODS op bindings (_irdl_ops_gen) + irdl.py wrappers
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/IRDLOps.td  # thin wrapper including mlir/Dialect/IRDL/IR/IRDLOps.td
+ SOURCES dialects/irdl.py
+ DIALECT_NAME irdl
+ GEN_ENUM_BINDINGS  # emits _irdl_enum_gen, star-imported by irdl.py
+)
+
################################################################################
# Python extensions.
# The sources for these are all in lib/Bindings/Python, but since they have to
@@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MLIRCAPITransformDialect
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind  # native module from DialectIRDL.cpp
+ MODULE_NAME _mlirDialectsIRDL  # must match NB_MODULE(...) and the import in irdl.py
+ ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
+ SOURCES
+ DialectIRDL.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPIIRDL  # C API behind mlirLoadIRDLDialects (mlir-c/Dialect/IRDL.h)
+)
+
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async
diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td
new file mode 100644
index 0000000000000..695839f1aa08b
--- /dev/null
+++ b/mlir/python/mlir/dialects/IRDLOps.td
@@ -0,0 +1,14 @@
+//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_IRDL_OPS
+#define PYTHON_BINDINGS_IRDL_OPS
+
+include "mlir/Dialect/IRDL/IR/IRDLOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
new file mode 100644
index 0000000000000..2314ee99950e0
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -0,0 +1,43 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._irdl_ops_gen import *
+from ._irdl_ops_gen import _Dialect
+from ._irdl_enum_gen import *
+from .._mlir_libs._mlirDialectsIRDL import *
+from ..ir import register_attribute_builder
+from ._ods_common import (
+ get_op_result_or_value as _get_value,
+ get_op_results_or_values as _get_values,
+ _cext as _ods_cext,
+)
+from ..extras.meta import region_op
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DialectOp(DialectOp):
+ """Specialization for the dialect op class."""
+
+ def __init__(self, sym_name, *, loc=None, ip=None):
+ super().__init__(sym_name, loc=loc, ip=ip)
+ self.regions[0].blocks.append()  # Create the empty block that `body` returns.
+
+ @property
+ def body(self):
+ return self.regions[0].blocks[0]
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
+ """Specialization for the operation op class."""
+
+ def __init__(self, sym_name, *, loc=None, ip=None):
+ super().__init__(sym_name, loc=loc, ip=ip)
+ self.regions[0].blocks.append()  # Create the empty block that `body` returns.
+
+ @property
+ def body(self):
+ return self.regions[0].blocks[0]
+
+@register_attribute_builder("VariadicityArrayAttr")
+def _variadicity_array_attr(x, context):
+ return _ods_cext.ir.Attribute.parse(f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>")
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
new file mode 100644
index 0000000000000..30983af302a52
--- /dev/null
+++ b/mlir/test/python/dialects/irdl.py
@@ -0,0 +1,45 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import irdl
+import sys
+
+
+def run(f):  # Test decorator: print a "TEST:" marker, then execute f immediately.
+ print("\nTEST:", f.__name__, file=sys.stderr)  # stderr reaches FileCheck via the RUN line's 2>&1.
+ f()  # No return value, so the decorated name is rebound to None.
+
+
+# CHECK: TEST: testIRDL
+@run
+def testIRDL():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ dialect = irdl.DialectOp("irdl_test")
+ with InsertionPoint(dialect.body):
+ op = irdl.OperationOp("test_op")
+ with InsertionPoint(op.body):
+ f32 = irdl.is_(TypeAttr.get(F32Type.get()))
+ irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
+
+ # CHECK: module {
+ # CHECK: irdl.dialect @irdl_test {
+ # CHECK: irdl.operation @test_op {
+ # CHECK: %0 = irdl.is f32
+ # CHECK: irdl.operands(input: %0)
+ # CHECK: }
+ # CHECK: }
+ # CHECK: }
+ module.dump()
+
+ irdl.load_dialects(module)
+
+ m = Module.parse("""
+ module {
+ %a = arith.constant 1.0 : f32
+ "irdl_test.test_op"(%a) : (f32) -> ()
+ }
+ """)
+ # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+ m.dump()
>From 99375299298093d26b020e86e69b9715f16b26ea Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:43:44 +0800
Subject: [PATCH 2/5] format
---
mlir/python/mlir/dialects/irdl.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
index 2314ee99950e0..743870162a8bb 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -14,6 +14,7 @@
)
from ..extras.meta import region_op
+
@_ods_cext.register_operation(_Dialect, replace=True)
class DialectOp(DialectOp):
"""Specialization for the dialect op class."""
@@ -26,6 +27,7 @@ def __init__(self, sym_name, *, loc=None, ip=None):
def body(self):
return self.regions[0].blocks[0]
+
@_ods_cext.register_operation(_Dialect, replace=True)
class OperationOp(OperationOp):
"""Specialization for the operation op class."""
@@ -38,6 +40,9 @@ def __init__(self, sym_name, *, loc=None, ip=None):
def body(self):
return self.regions[0].blocks[0]
+
@register_attribute_builder("VariadicityArrayAttr")
def _variadicity_array_attr(x, context):
- return _ods_cext.ir.Attribute.parse(f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>")
+ return _ods_cext.ir.Attribute.parse(
+ f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>", context=context
+ )
>From 5bafd4247920ec9db95468c130d3cba885e50ca1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:50:03 +0800
Subject: [PATCH 3/5] format
---
mlir/test/python/dialects/irdl.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index 30983af302a52..0c4007b4d2a95 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -35,11 +35,13 @@ def testIRDL():
irdl.load_dialects(module)
- m = Module.parse("""
+ m = Module.parse(
+ """
module {
%a = arith.constant 1.0 : f32
"irdl_test.test_op"(%a) : (f32) -> ()
}
- """)
+ """
+ )
# CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
m.dump()
>From fc7728d98cac3a3c155680a0030b5e18a00a07f3 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:56:46 +0800
Subject: [PATCH 4/5] refine test case
---
mlir/test/python/dialects/irdl.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index 0c4007b4d2a95..307a3f90af8b6 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -43,5 +43,7 @@ def testIRDL():
}
"""
)
- # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+ # CHECK: module {
+ # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+ # CHECK: }
m.dump()
>From df14b403ad05b9dea48e6206dad5df0a5a75339f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 17 Sep 2025 10:54:43 +0800
Subject: [PATCH 5/5] add class for typeop and attributeop
---
mlir/python/mlir/dialects/irdl.py | 26 ++++++++++++++++++++++++++
mlir/test/python/dialects/irdl.py | 17 +++++++++++++++++
2 files changed, 43 insertions(+)
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
index 743870162a8bb..8fb2e21bbe0d3 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -41,6 +41,32 @@ def body(self):
return self.regions[0].blocks[0]
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TypeOp(TypeOp):
+ """Specialization for the type op class."""
+
+ def __init__(self, sym_name, *, loc=None, ip=None):
+ super().__init__(sym_name, loc=loc, ip=ip)
+ self.regions[0].blocks.append()  # Create the empty block that `body` returns.
+
+ @property
+ def body(self):
+ return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AttributeOp(AttributeOp):
+ """Specialization for the attribute op class."""
+
+ def __init__(self, sym_name, *, loc=None, ip=None):
+ super().__init__(sym_name, loc=loc, ip=ip)
+ self.regions[0].blocks.append()  # Create the empty block that `body` returns.
+
+ @property
+ def body(self):
+ return self.regions[0].blocks[0]
+
+
@register_attribute_builder("VariadicityArrayAttr")
def _variadicity_array_attr(x, context):
return _ods_cext.ir.Attribute.parse(
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index 307a3f90af8b6..e01c5be238976 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -22,6 +22,14 @@ def testIRDL():
with InsertionPoint(op.body):
f32 = irdl.is_(TypeAttr.get(F32Type.get()))
irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
+ type1 = irdl.TypeOp("type1")
+ with InsertionPoint(type1.body):
+ f32 = irdl.is_(TypeAttr.get(F32Type.get()))
+ irdl.parameters([f32], ["val"])
+ attr1 = irdl.AttributeOp("attr1")
+ with InsertionPoint(attr1.body):
+ test = irdl.is_(StringAttr.get("test"))
+ irdl.parameters([test], ["val"])
# CHECK: module {
# CHECK: irdl.dialect @irdl_test {
@@ -29,8 +37,17 @@ def testIRDL():
# CHECK: %0 = irdl.is f32
# CHECK: irdl.operands(input: %0)
# CHECK: }
+ # CHECK: irdl.type @type1 {
+ # CHECK: %0 = irdl.is f32
+ # CHECK: irdl.parameters(val: %0)
+ # CHECK: }
+ # CHECK: irdl.attribute @attr1 {
+ # CHECK: %0 = irdl.is "test"
+ # CHECK: irdl.parameters(val: %0)
+ # CHECK: }
# CHECK: }
# CHECK: }
+ module.operation.verify()
module.dump()
irdl.load_dialects(module)
More information about the Mlir-commits
mailing list