[Mlir-commits] [mlir] [MLIR][Python] Add python bindings for IRDL dialect (PR #158488)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Sun Sep 14 07:27:27 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-irdl
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
This PR adds basic Python bindings for the IRDL dialect, so that Python users can create and load IRDL dialects directly from Python. To some extent, this lets users define dialects in Python without modifying MLIR’s CMake/TableGen/C++ code and rebuilding, which makes prototyping more convenient.
A basic example is shown below (and also in the added test case):
```python
from mlir.ir import *
from mlir.dialects import irdl

# IR construction needs an active Context and Location.
with Context(), Location.unknown():
  # create a module with IRDL dialects
  module = Module.create()
  with InsertionPoint(module.body):
    dialect = irdl.DialectOp("irdl_test")
    with InsertionPoint(dialect.body):
      op = irdl.OperationOp("test_op")
      with InsertionPoint(op.body):
        f32 = irdl.is_(TypeAttr.get(F32Type.get()))
        irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
  # load the module
  irdl.load_dialects(module)
  # use the op defined in IRDL
  m = Module.parse("""
    module {
      %a = arith.constant 1.0 : f32
      "irdl_test.test_op"(%a) : (f32) -> ()
    }
  """)
```
---
Full diff: https://github.com/llvm/llvm-project/pull/158488.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td (+4-4) 
- (modified) mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td (+1-1) 
- (added) mlir/lib/Bindings/Python/DialectIRDL.cpp (+36) 
- (modified) mlir/python/CMakeLists.txt (+23) 
- (added) mlir/python/mlir/dialects/IRDLOps.td (+14) 
- (added) mlir/python/mlir/dialects/irdl.py (+43) 
- (added) mlir/test/python/dialects/irdl.py (+45) 
``````````diff
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
index 1bfa7f5cb894b..2f568e8b6c42a 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
@@ -13,7 +13,7 @@
 #ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
 #define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
 
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/EnumAttr.td"
 
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 4a83eb62fba32..3b6b09973645c 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -13,10 +13,10 @@
 #ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
 #define MLIR_DIALECT_IRDL_IR_IRDLOPS
 
-include "IRDL.td"
-include "IRDLAttributes.td"
-include "IRDLTypes.td"
-include "IRDLInterfaces.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
+include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
+include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
index 9b17bf23df182..9cde433cf33a6 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
@@ -14,7 +14,7 @@
 #define MLIR_DIALECT_IRDL_IR_IRDLTYPES
 
 include "mlir/IR/AttrTypeBase.td"
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
 
 class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
     : TypeDef<IRDL_Dialect, name, traits> {
diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp
new file mode 100644
index 0000000000000..8264d21d4fa03
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp
@@ -0,0 +1,36 @@
+//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/IRDL.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+/// Adds `load_dialects(module)` to the extension module `m`.
+static void populateDialectIRDLSubmodule(nb::module_ &m) {
+  // A failed load is reported by throwing std::runtime_error, which nanobind
+  // translates into a Python exception.
+  auto loadDialects = [](MlirModule irdlModule) {
+    if (!mlirLogicalResultIsFailure(mlirLoadIRDLDialects(irdlModule)))
+      return;
+    throw std::runtime_error(
+        "failed to load IRDL dialects from the input module");
+  };
+  m.def("load_dialects", loadDialects, nb::arg("module"),
+        "Load IRDL dialects from the given module.");
+}
+
+/// Native extension imported by `mlir.dialects.irdl` as `_mlirDialectsIRDL`.
+NB_MODULE(_mlirDialectsIRDL, m) {
+  populateDialectIRDLSubmodule(m);
+  m.doc() = "MLIR IRDL dialect.";
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c983914722ce1..7b2e1b8c36f25 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
   GEN_ENUM_BINDINGS_TD_FILE
     "dialects/VectorAttributes.td")
 
+# IRDL dialect: generated op/enum bindings plus the hand-written irdl.py.
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  DIALECT_NAME irdl
+  TD_FILE dialects/IRDLOps.td
+  GEN_ENUM_BINDINGS
+  SOURCES
+    dialects/irdl.py)
+
 ################################################################################
 # Python extensions.
 # The sources for these are all in lib/Bindings/Python, but since they have to
@@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
     MLIRCAPITransformDialect
 )
 
+# Native `_mlirDialectsIRDL` module (provides load_dialects) for the irdl package.
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
+  MODULE_NAME _mlirDialectsIRDL
+  ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
+  PYTHON_BINDINGS_LIBRARY nanobind
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES DialectIRDL.cpp
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPIIRDL
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
   MODULE_NAME _mlirAsyncPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.async
diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td
new file mode 100644
index 0000000000000..695839f1aa08b
--- /dev/null
+++ b/mlir/python/mlir/dialects/IRDLOps.td
@@ -0,0 +1,14 @@
+//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_IRDL_OPS
+#define PYTHON_BINDINGS_IRDL_OPS
+
+include "mlir/Dialect/IRDL/IR/IRDLOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
new file mode 100644
index 0000000000000..2314ee99950e0
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -0,0 +1,43 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._irdl_ops_gen import *
+from ._irdl_ops_gen import _Dialect
+from ._irdl_enum_gen import *
+from .._mlir_libs._mlirDialectsIRDL import *
+from ..ir import register_attribute_builder
+from ._ods_common import (
+    get_op_result_or_value as _get_value,
+    get_op_results_or_values as _get_values,
+    _cext as _ods_cext,
+)
+from ..extras.meta import region_op
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class DialectOp(DialectOp):
+    """Specialization for the dialect op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
+    """Specialization for the operation op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+ at register_attribute_builder("VariadicityArrayAttr")
+def _variadicity_array_attr(x, context):
+    return _ods_cext.ir.Attribute.parse(f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>")
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
new file mode 100644
index 0000000000000..30983af302a52
--- /dev/null
+++ b/mlir/test/python/dialects/irdl.py
@@ -0,0 +1,45 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import irdl
+import sys
+
+
+def run(f):
+    print("\nTEST:", f.__name__, file=sys.stderr)
+    f()
+
+
+# CHECK: TEST: testIRDL
+ at run
+def testIRDL():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            dialect = irdl.DialectOp("irdl_test")
+            with InsertionPoint(dialect.body):
+                op = irdl.OperationOp("test_op")
+                with InsertionPoint(op.body):
+                    f32 = irdl.is_(TypeAttr.get(F32Type.get()))
+                    irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
+
+        # CHECK: module {
+        # CHECK:   irdl.dialect @irdl_test {
+        # CHECK:     irdl.operation @test_op {
+        # CHECK:       %0 = irdl.is f32
+        # CHECK:       irdl.operands(input: %0)
+        # CHECK:     }
+        # CHECK:   }
+        # CHECK: }
+        module.dump()
+
+        irdl.load_dialects(module)
+
+        m = Module.parse("""
+          module {
+            %a = arith.constant 1.0 : f32
+            "irdl_test.test_op"(%a) : (f32) -> ()
+          }
+        """)
+        # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+        m.dump()
``````````
</details>
https://github.com/llvm/llvm-project/pull/158488
    
    
More information about the Mlir-commits mailing list