[Mlir-commits] [mlir] [MLIR][Python] Add python bindings for IRDL dialect (PR #158488)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 16 19:55:38 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/158488

>From 8cbbb6b7a3b21775498cf12bf64ebd7295d4cdf8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:16:38 +0800
Subject: [PATCH 1/5] [MLIR][Python] Add python bindings for IRDL dialect

---
 .../mlir/Dialect/IRDL/IR/IRDLAttributes.td    |  2 +-
 mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td  |  8 ++--
 .../include/mlir/Dialect/IRDL/IR/IRDLTypes.td |  2 +-
 mlir/lib/Bindings/Python/DialectIRDL.cpp      | 36 +++++++++++++++
 mlir/python/CMakeLists.txt                    | 23 ++++++++++
 mlir/python/mlir/dialects/IRDLOps.td          | 14 ++++++
 mlir/python/mlir/dialects/irdl.py             | 43 ++++++++++++++++++
 mlir/test/python/dialects/irdl.py             | 45 +++++++++++++++++++
 8 files changed, 167 insertions(+), 6 deletions(-)
 create mode 100644 mlir/lib/Bindings/Python/DialectIRDL.cpp
 create mode 100644 mlir/python/mlir/dialects/IRDLOps.td
 create mode 100644 mlir/python/mlir/dialects/irdl.py
 create mode 100644 mlir/test/python/dialects/irdl.py

diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
index 1bfa7f5cb894b..2f568e8b6c42a 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
@@ -13,7 +13,7 @@
 #ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
 #define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
 
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/EnumAttr.td"
 
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 4a83eb62fba32..3b6b09973645c 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -13,10 +13,10 @@
 #ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
 #define MLIR_DIALECT_IRDL_IR_IRDLOPS
 
-include "IRDL.td"
-include "IRDLAttributes.td"
-include "IRDLTypes.td"
-include "IRDLInterfaces.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
+include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
+include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
index 9b17bf23df182..9cde433cf33a6 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
@@ -14,7 +14,7 @@
 #define MLIR_DIALECT_IRDL_IR_IRDLTYPES
 
 include "mlir/IR/AttrTypeBase.td"
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
 
 class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
     : TypeDef<IRDL_Dialect, name, traits> {
diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp
new file mode 100644
index 0000000000000..8264d21d4fa03
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp
@@ -0,0 +1,36 @@
+//===--- DialectIRDL.cpp - Nanobind module for IRDL dialect API support ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/IRDL.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectIRDLSubmodule(nb::module_ &m) {
+  // Throws std::runtime_error if mlirLoadIRDLDialects reports failure.
+  auto loadDialects = [](MlirModule module) {
+    if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
+      throw std::runtime_error(
+          "failed to load IRDL dialects from the input module");
+  };
+  m.def("load_dialects", loadDialects, nb::arg("module"),
+        "Load IRDL dialects from the given module.");
+}
+
+NB_MODULE(_mlirDialectsIRDL, m) {
+  // Imported from Python as mlir._mlir_libs._mlirDialectsIRDL.
+  populateDialectIRDLSubmodule(m);
+  m.doc() = "MLIR IRDL dialect.";
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c983914722ce1..7b2e1b8c36f25 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
   GEN_ENUM_BINDINGS_TD_FILE
     "dialects/VectorAttributes.td")
 
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  DIALECT_NAME irdl
+  TD_FILE dialects/IRDLOps.td
+  SOURCES dialects/irdl.py
+  GEN_ENUM_BINDINGS
+)
+
 ################################################################################
 # Python extensions.
 # The sources for these are all in lib/Bindings/Python, but since they have to
@@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
     MLIRCAPITransformDialect
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
+  MODULE_NAME _mlirDialectsIRDL
+  ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
+  SOURCES
+    DialectIRDL.cpp
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPIIRDL
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
   MODULE_NAME _mlirAsyncPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.async
diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td
new file mode 100644
index 0000000000000..695839f1aa08b
--- /dev/null
+++ b/mlir/python/mlir/dialects/IRDLOps.td
@@ -0,0 +1,14 @@
+//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_IRDL_OPS
+#define PYTHON_BINDINGS_IRDL_OPS
+
+include "mlir/Dialect/IRDL/IR/IRDLOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
new file mode 100644
index 0000000000000..2314ee99950e0
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -0,0 +1,43 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._irdl_ops_gen import *
+from ._irdl_ops_gen import _Dialect
+from ._irdl_enum_gen import *
+from .._mlir_libs._mlirDialectsIRDL import *
+from ..ir import register_attribute_builder
+from ._ods_common import (
+    get_op_result_or_value as _get_value,
+    get_op_results_or_values as _get_values,
+    _cext as _ods_cext,
+)
+from ..extras.meta import region_op
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DialectOp(DialectOp):
+    """Specialization for the dialect op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()  # body block; `body` returns blocks[0]
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
+    """Specialization for the operation op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()  # body block; `body` returns blocks[0]
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+@register_attribute_builder("VariadicityArrayAttr")
+def _variadicity_array_attr(x, context):
+    return _ods_cext.ir.Attribute.parse(f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>")
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
new file mode 100644
index 0000000000000..30983af302a52
--- /dev/null
+++ b/mlir/test/python/dialects/irdl.py
@@ -0,0 +1,45 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import irdl
+import sys
+
+
+def run(f):
+    sys.stderr.write("\nTEST: " + f.__name__ + "\n")
+    f()
+
+
+# CHECK: TEST: testIRDL
+@run
+def testIRDL():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            dialect = irdl.DialectOp("irdl_test")
+            with InsertionPoint(dialect.body):
+                op = irdl.OperationOp("test_op")
+                with InsertionPoint(op.body):
+                    f32 = irdl.is_(TypeAttr.get(F32Type.get()))
+                    irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
+
+        # CHECK: module {
+        # CHECK:   irdl.dialect @irdl_test {
+        # CHECK:     irdl.operation @test_op {
+        # CHECK:       %0 = irdl.is f32
+        # CHECK:       irdl.operands(input: %0)
+        # CHECK:     }
+        # CHECK:   }
+        # CHECK: }
+        module.dump()
+
+        irdl.load_dialects(module)
+
+        m = Module.parse("""
+          module {
+            %a = arith.constant 1.0 : f32
+            "irdl_test.test_op"(%a) : (f32) -> ()
+          }
+        """)
+        # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+        m.dump()

>From 99375299298093d26b020e86e69b9715f16b26ea Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:43:44 +0800
Subject: [PATCH 2/5] format

---
 mlir/python/mlir/dialects/irdl.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
index 2314ee99950e0..743870162a8bb 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -14,6 +14,7 @@
 )
 from ..extras.meta import region_op
 
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class DialectOp(DialectOp):
     """Specialization for the dialect op class."""
@@ -26,6 +27,7 @@ def __init__(self, sym_name, *, loc=None, ip=None):
     def body(self):
         return self.regions[0].blocks[0]
 
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class OperationOp(OperationOp):
     """Specialization for the operation op class."""
@@ -38,6 +40,9 @@ def __init__(self, sym_name, *, loc=None, ip=None):
     def body(self):
         return self.regions[0].blocks[0]
 
+
 @register_attribute_builder("VariadicityArrayAttr")
 def _variadicity_array_attr(x, context):
-    return _ods_cext.ir.Attribute.parse(f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>")
+    return _ods_cext.ir.Attribute.parse(
+        f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>", context=context
+    )

>From 5bafd4247920ec9db95468c130d3cba885e50ca1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:50:03 +0800
Subject: [PATCH 3/5] format

---
 mlir/test/python/dialects/irdl.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index 30983af302a52..0c4007b4d2a95 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -35,11 +35,13 @@ def testIRDL():
 
         irdl.load_dialects(module)
 
-        m = Module.parse("""
+        m = Module.parse(
+            """
           module {
             %a = arith.constant 1.0 : f32
             "irdl_test.test_op"(%a) : (f32) -> ()
           }
-        """)
+        """
+        )
         # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
         m.dump()

>From fc7728d98cac3a3c155680a0030b5e18a00a07f3 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 14 Sep 2025 22:56:46 +0800
Subject: [PATCH 4/5] refine test case

---
 mlir/test/python/dialects/irdl.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index 0c4007b4d2a95..307a3f90af8b6 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -43,5 +43,7 @@ def testIRDL():
           }
         """
         )
-        # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+        # CHECK: module {
+        # CHECK:   "irdl_test.test_op"(%cst) : (f32) -> ()
+        # CHECK: }
         m.dump()

>From df14b403ad05b9dea48e6206dad5df0a5a75339f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 17 Sep 2025 10:54:43 +0800
Subject: [PATCH 5/5] add class for typeop and attributeop

---
 mlir/python/mlir/dialects/irdl.py | 26 ++++++++++++++++++++++++++
 mlir/test/python/dialects/irdl.py | 17 +++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
index 743870162a8bb..8fb2e21bbe0d3 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -41,6 +41,32 @@ def body(self):
         return self.regions[0].blocks[0]
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TypeOp(TypeOp):
+    """Specialization for the type op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()  # body block; `body` returns blocks[0]
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AttributeOp(AttributeOp):
+    """Specialization for the attribute op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()  # body block; `body` returns blocks[0]
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+
 @register_attribute_builder("VariadicityArrayAttr")
 def _variadicity_array_attr(x, context):
     return _ods_cext.ir.Attribute.parse(
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index 307a3f90af8b6..e01c5be238976 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -22,6 +22,14 @@ def testIRDL():
                 with InsertionPoint(op.body):
                     f32 = irdl.is_(TypeAttr.get(F32Type.get()))
                     irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
+                type1 = irdl.TypeOp("type1")
+                with InsertionPoint(type1.body):
+                    f32 = irdl.is_(TypeAttr.get(F32Type.get()))
+                    irdl.parameters([f32], ["val"])
+                attr1 = irdl.AttributeOp("attr1")
+                with InsertionPoint(attr1.body):
+                    test = irdl.is_(StringAttr.get("test"))
+                    irdl.parameters([test], ["val"])
 
         # CHECK: module {
         # CHECK:   irdl.dialect @irdl_test {
@@ -29,8 +37,17 @@ def testIRDL():
         # CHECK:       %0 = irdl.is f32
         # CHECK:       irdl.operands(input: %0)
         # CHECK:     }
+        # CHECK:     irdl.type @type1 {
+        # CHECK:       %0 = irdl.is f32
+        # CHECK:       irdl.parameters(val: %0)
+        # CHECK:     }
+        # CHECK:     irdl.attribute @attr1 {
+        # CHECK:       %0 = irdl.is "test"
+        # CHECK:       irdl.parameters(val: %0)
+        # CHECK:     }
         # CHECK:   }
         # CHECK: }
+        module.operation.verify()
         module.dump()
 
         irdl.load_dialects(module)



More information about the Mlir-commits mailing list