[Mlir-commits] [mlir] [MLIR][Python] Add support for `convert_region_types` and the bf integration test (PR #183664)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 27 04:51:40 PST 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/183664

>From d95cc8dc9491efa5e253ddbc1e764318d5854467 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 26 Feb 2026 22:07:07 +0800
Subject: [PATCH 1/4] [MLIR][Python] Support op adaptor for Python-defined
 operations

---
 mlir/python/mlir/dialects/ext.py | 50 +++++++++++++++++++++++++++++++-
 mlir/test/python/dialects/ext.py |  8 +++++
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 39aacf32dabb9..d88e25cced8f6 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -41,7 +41,17 @@
 Region = ir.Region
 
 register_dialect = _cext.register_dialect
-register_operation = _cext.register_operation
+
+
+def register_operation(dialect_cls: type) -> Callable[[type], type]:  # Decorator factory: registers an op class with `dialect_cls` and then its Adaptor.
+    register = _cext.register_operation(dialect_cls)
+
+    def decorator(op_cls: type) -> type:
+        register(op_cls)
+        _cext.register_op_adaptor(op_cls)(op_cls.Adaptor)  # assumes op_cls has an `Adaptor` attribute (Python-defined ops get one in __init_subclass__)
+        return op_cls  # hand back the class itself so the decorated name stays bound to it
+
+    return decorator
 
 
 def construct_instance(origin, args):
@@ -307,6 +317,13 @@ def __init_subclass__(
         cls._generate_result_properties(results)
         cls._generate_region_properties(regions)
 
+        cls.Adaptor = type(
+            "Adaptor",
+            (OperationAdator,),
+            dict(),
+            operation=cls,
+        )
+
         dialect_obj.operations.append(cls)
 
     @staticmethod
@@ -507,6 +524,37 @@ def _emit_operation(cls) -> None:
                 )
 
 
+class OperationAdator(ir.OpAdaptor):  # Base for per-op adaptors; subclassing with `operation=` adds named accessors. NOTE(review): "Adator" looks like a misspelling of "Adaptor".
+    @classmethod  # redundant but harmless: __init_subclass__ is implicitly a classmethod
+    def __init_subclass__(cls, *, operation: type):
+        cls.OPERATION_NAME = operation.OPERATION_NAME  # mirror the op's name on the adaptor class
+        cls._operation_cls = operation  # back-reference to the op class this adaptor describes
+
+        operands, attrs, results, regions = partition_fields(operation._fields)  # results/regions are unpacked but unused here
+
+        for attr in attrs:
+            setattr(
+                cls,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),  # `name=` default binds this iteration's attr name
+            )
+
+        for i, operand in enumerate(operands):
+            if operation._ODS_OPERAND_SEGMENTS:  # segmented ops: slice operands using operandSegmentSizes
+
+                def getter(self, i=i, operand=operand):  # defaults capture this iteration's i/operand
+                    operand_range = segmented_accessor(
+                        self.operands,
+                        self.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return normalize_value_range(operand_range, operand.variadicity)
+
+                setattr(cls, operand.name, property(getter))
+            else:
+                setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))  # non-segmented: field i is the i-th operand
+
+
 @dataclass
 class ParamDef:
     name: str
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f9252bad37a39..5b3f9d8416517 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -91,6 +91,14 @@ class AddOp(Operation, dialect=MyInt, name="add"):
         # CHECK: (self, /, value, *, loc=None, ip=None)
         print(ConstantOp.__init__.__signature__)
 
+        # CHECK: True
+        print(issubclass(AddOp.Adaptor, OpAdaptor))
+        adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
+        # CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
+        print(adaptor1.lhs)
+        # CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
+        print(adaptor1.rhs)
+
 
 # CHECK: TEST: testExtDialect
 @run

>From 5a5b254f50eb3e7756e057297eacc5b9855f72f8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 26 Feb 2026 22:08:43 +0800
Subject: [PATCH 2/4] [MLIR][Python] Check OPERATION_NAME on the op adaptor in the ext test

---
 mlir/test/python/dialects/ext.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 5b3f9d8416517..2921615e75d54 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -94,6 +94,8 @@ class AddOp(Operation, dialect=MyInt, name="add"):
         # CHECK: True
         print(issubclass(AddOp.Adaptor, OpAdaptor))
         adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
+        # CHECK: myint.add
+        print(adaptor1.OPERATION_NAME)
         # CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
         print(adaptor1.lhs)
         # CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)

>From 4932c88a958ce8b2d6d5b34b8b1d04104440abfe Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 26 Feb 2026 23:31:40 +0800
Subject: [PATCH 3/4] [MLIR][Python] Add support of convert_region_types and
 the bf integration test

---
 mlir/include/mlir-c/Rewrite.h               |   6 +
 mlir/lib/Bindings/Python/Rewrite.cpp        |  15 +-
 mlir/lib/CAPI/Transforms/Rewrite.cpp        |   7 +
 mlir/test/python/integration/dialects/bf.py | 252 ++++++++++++++++++++
 4 files changed, 277 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/python/integration/dialects/bf.py

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index b4f93fd5a9b78..cc1be8b91a481 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -490,6 +490,12 @@ MLIR_CAPI_EXPORTED MlirPatternRewriter
 mlirConversionPatternRewriterAsPatternRewriter(
     MlirConversionPatternRewriter rewriter);
 
+/// Apply `typeConverter`'s signature conversion to each block in `region`; the returned result reports whether the conversion succeeded.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirConversionPatternRewriterConvertRegionTypes(
+    MlirConversionPatternRewriter rewriter, MlirRegion region,
+    MlirTypeConverter typeConverter);
+
 //===----------------------------------------------------------------------===//
 /// ConversionTarget API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e370552c00a9a..22eaccce75883 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,11 +35,14 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
       : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
 };
 
-class PyConversionPatternRewriter : PyPatternRewriter {
+class PyConversionPatternRewriter : public PyPatternRewriter { // public base so nb::class_<..., PyPatternRewriter> can expose it as a subclass
public:
   PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
       : PyPatternRewriter(
-            mlirConversionPatternRewriterAsPatternRewriter(rewriter)) {}
+            mlirConversionPatternRewriterAsPatternRewriter(rewriter)),
+        rewriter(rewriter) {}
+
+  MlirConversionPatternRewriter rewriter; // kept for conversion-specific C APIs (e.g. convert_region_types)
 };
 
 class PyConversionTarget {
@@ -568,7 +571,13 @@ void populateRewriteSubmodule(nb::module_ &m) {
            "Freeze the pattern set into a frozen one.");
 
   nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
-      m, "ConversionPatternRewriter");
+      m, "ConversionPatternRewriter")
+      .def("convert_region_types",
+           [](PyConversionPatternRewriter &self, PyRegion &region,
+              PyTypeConverter &typeConverter) {
+             mlirConversionPatternRewriterConvertRegionTypes(
+                 self.rewriter, region.get(), typeConverter.get());
+           });
 
   nb::class_<PyConversionTarget>(m, "ConversionTarget")
       .def(
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 5900f08ae1730..1a3ab2f66382a 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -527,6 +527,13 @@ MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
   return wrap(static_cast<mlir::PatternRewriter *>(unwrap(rewriter)));
 }
 
+MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes( // C shim over ConversionPatternRewriter::convertRegionTypes
+    MlirConversionPatternRewriter rewriter, MlirRegion region,
+    MlirTypeConverter typeConverter) {
+  return wrap(unwrap(rewriter)->convertRegionTypes(unwrap(region), // NOTE(review): assumes the C++ result converts to LogicalResult, dropping any returned entry block; confirm
+                                                   *unwrap(typeConverter)));
+}
+
 //===----------------------------------------------------------------------===//
 /// ConversionTarget API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
new file mode 100644
index 0000000000000..6eb55c3249c46
--- /dev/null
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -0,0 +1,252 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+# REQUIRES: host-supports-jit
+
+from mlir.ir import *
+from mlir.dialects.ext import *
+from mlir.rewrite import *
+from mlir.passmanager import *
+from mlir.execution_engine import *
+from mlir.dialects import llvm, scf, func
+from functools import partial
+
+
+class BfDialect(Dialect, name="bf"):  # Python-defined "bf" (Brainfuck) dialect; parse() maps source characters to its ops
+    pass
+
+
+class PtrType(BfDialect.Type, name="ptr"):  # tape pointer type; convert_ptr lowers it to !llvm.ptr
+    pass
+
+
+class Ptr2Type(BfDialect.Type, name="ptr2"):
+    pass
+
+
+class NextOp(BfDialect.Operation, name="next"):
+    in_: Operand[PtrType]
+    out: Result[PtrType[()]]
+
+
+class PrevOp(BfDialect.Operation, name="prev"):  # "<": pointer moved back one cell (counterpart of NextOp's ">")
+    in_: Operand[PtrType]
+    out: Result[PtrType[()]]
+
+
+class IncOp(BfDialect.Operation, name="inc"):  # "+": increment the current cell
+    in_: Operand[PtrType]
+
+
+class DecOp(BfDialect.Operation, name="dec"):  # "-": decrement the current cell
+    in_: Operand[PtrType]
+
+
+class InputOp(BfDialect.Operation, name="input"):  # ",": read one byte into the current cell
+    in_: Operand[PtrType]
+
+
+class OutputOp(BfDialect.Operation, name="output"):  # ".": write the current cell
+    in_: Operand[PtrType]
+
+
+class WhileOp(BfDialect.Operation, name="while"):  # "[...]": loop while the current cell is non-zero
+    in_: Operand[PtrType]
+    out: Result[PtrType[()]]
+    body: Region
+
+
+class YieldOp(BfDialect.Operation, name="yield", traits=[IsTerminatorTrait]):  # terminator passing the pointer out of a while/main body
+    in_: Operand[PtrType]
+
+
+class MainOp(BfDialect.Operation, name="main"):  # program root; its single block takes the initial pointer as argument
+    body: Region
+
+
+def parse(code: str) -> Module:  # Build a module holding one MainOp translated from bf source text.
+    module = Module.create()
+
+    with InsertionPoint(module.body):
+        main = MainOp()
+        main.body.blocks.append()
+        current_val = main.body.blocks[0].add_argument(
+            PtrType.get(), Location.unknown()
+        )
+
+        ip = InsertionPoint(main.body.blocks[0])  # moves into/out of loop bodies as brackets open/close
+        for c in code:  # characters other than ><+-.,[] fall through every branch and are ignored
+            with ip:
+                if c == ">":
+                    current_val = NextOp(current_val).out
+                elif c == "<":
+                    current_val = PrevOp(current_val).out
+                elif c == "+":
+                    IncOp(current_val)
+                elif c == "-":
+                    DecOp(current_val)
+                elif c == ".":
+                    OutputOp(current_val)
+                elif c == ",":
+                    InputOp(current_val)
+                elif c == "[":  # open a WhileOp; following ops go into its body block
+                    loop = WhileOp(current_val)
+                    loop.body.blocks.append()
+                    current_val = loop.body.blocks[0].add_argument(
+                        PtrType.get(), Location.unknown()
+                    )
+                    ip = InsertionPoint(loop.body.blocks[0])
+                elif c == "]":  # NOTE(review): assumes balanced brackets; a stray "]" would reach MainOp, which has no `out`
+                    YieldOp(current_val)
+                    current_val = ip.block.owner.opview.out  # continue with the enclosing WhileOp's result
+                    ip = InsertionPoint.after(current_val.owner)
+
+        with ip:  # terminate the main body (an unclosed "[" would put this inside a loop instead)
+            YieldOp(current_val)
+
+    return module
+
+
+def convert_bf_to_llvm(op, pass_):
+    patterns = RewritePatternSet()
+    ptr = llvm.PointerType.get()
+    i8 = IntegerType.get_signless(8)
+    i32 = IntegerType.get_signless(32)
+
+    type_converter = TypeConverter()
+
+    def convert_ptr(t):
+        return ptr if isinstance(t, PtrType) else None
+
+    type_converter.add_conversion(convert_ptr)
+
+    def convert_next(op, adaptor, converter, rewriter, offset=1):
+        with rewriter.ip:
+            gep = llvm.GEPOp(ptr, adaptor.in_, [], [offset], i8, [])
+        rewriter.replace_op(op, gep)
+
+    def convert_inc(op, adaptor, converter, rewriter, cst=1):
+        with rewriter.ip:
+            load = llvm.load(i8, adaptor.in_)
+            one = llvm.mlir_constant(IntegerAttr.get(i8, cst))
+            added = llvm.add(load, one, [])
+            store = llvm.StoreOp(added, adaptor.in_)
+        rewriter.replace_op(op, store)
+
+    def convert_main(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            fn = func.FuncOp("bf_main", FunctionType.get([ptr], [ptr]))
+            op.body.blocks[0].append_to(fn.body)
+            rewriter.convert_region_types(fn.body, converter)
+        rewriter.replace_op(op, fn)
+
+    def convert_yield(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            if isinstance(op.parent.opview, WhileOp | scf.WhileOp):
+                yield_ = scf.YieldOp([adaptor.in_])
+            else:
+                yield_ = func.ReturnOp([adaptor.in_])
+        rewriter.replace_op(op, yield_)
+
+    def convert_while(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            loop = scf.WhileOp([ptr], [adaptor.in_])
+            loop.before.blocks.append()
+            arg = loop.before.blocks[0].add_argument(ptr, Location.unknown())
+            with InsertionPoint(loop.before.blocks[0]):
+                c = llvm.load(i8, arg)
+                zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+                cond = llvm.icmp(llvm.ICmpPredicate.ne, c, zero)
+                scf.ConditionOp(cond, [arg])
+            op.body.blocks[0].append_to(loop.after)
+            rewriter.convert_region_types(loop.after, converter)
+        rewriter.replace_op(op, loop)
+
+    def convert_output(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            val = llvm.load(i8, adaptor.in_)
+            call = func.CallOp([], "bf_output", [val])
+        rewriter.replace_op(op, call)
+
+    def convert_input(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            call = func.call([i8], "bf_input", [])
+            store = llvm.StoreOp(call, adaptor.in_)
+        rewriter.replace_op(op, store)
+
+    patterns.add_conversion(NextOp, convert_next, type_converter)
+    patterns.add_conversion(PrevOp, partial(convert_next, offset=-1), type_converter)
+    patterns.add_conversion(IncOp, convert_inc, type_converter)
+    patterns.add_conversion(DecOp, partial(convert_inc, cst=-1), type_converter)
+    patterns.add_conversion(MainOp, convert_main, type_converter)
+    patterns.add_conversion(YieldOp, convert_yield, type_converter)
+    patterns.add_conversion(WhileOp, convert_while, type_converter)
+    patterns.add_conversion(OutputOp, convert_output, type_converter)
+    patterns.add_conversion(InputOp, convert_input, type_converter)
+
+    target = ConversionTarget()
+    target.add_illegal_dialect(BfDialect)
+
+    config = ConversionConfig()
+    config.build_materializations = False
+
+    apply_partial_conversion(op, target, patterns.freeze(), config)
+
+    with InsertionPoint(op.opview.body):
+        func.FuncOp("putchar", FunctionType.get([i32], [i32]), visibility="private")
+        func.FuncOp("getchar", FunctionType.get([], [i32]), visibility="private")
+
+        output = func.FuncOp("bf_output", FunctionType.get([i8], []))
+        output.body.blocks.append()
+        arg = output.body.blocks[0].add_argument(i8, Location.unknown())
+        with InsertionPoint(output.body.blocks[0]):
+            sext = llvm.sext(i32, arg)
+            func.call([i32], "putchar", [sext])
+            func.ReturnOp([])
+
+        input = func.FuncOp("bf_input", FunctionType.get([], [i8]))
+        input.body.blocks.append()
+        with InsertionPoint(input.body.blocks[0]):
+            call = func.call([i32], "getchar", [])
+            trunc = llvm.trunc(i8, call, [])
+            func.ReturnOp([trunc])
+
+        init = func.FuncOp("bf_init", FunctionType.get([], []))
+        init.attributes["llvm.emit_c_interface"] = UnitAttr.get()
+        init.body.blocks.append()
+        with InsertionPoint(init.body.blocks[0]):
+            c1024 = llvm.mlir_constant(IntegerAttr.get(i32, 1024))
+            zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+            p = llvm.alloca(ptr, c1024, i8)
+            llvm.intr_memset(p, zero, c1024, False)
+            func.call([ptr], "bf_main", [p])
+            func.ReturnOp([])
+
+
+def execute(code):
+    module = parse(code)
+    assert module.operation.verify()
+
+    pm = PassManager()
+    pm.add(convert_bf_to_llvm)
+    pm.add("convert-scf-to-cf, convert-to-llvm")
+
+    pm.run(module.operation)
+
+    ee = ExecutionEngine(module)
+    ee.lookup("bf_init")(0)
+
+
+def run(f):  # Test decorator: print a FileCheck header, then call the test immediately.
+    print("TEST:", f.__name__)
+    f()  # no return value, so the decorated name ends up bound to None
+
+
+# CHECK: TEST: test_bf
+ at run
+def test_bf():
+    with Context(), Location.unknown():
+        BfDialect.load()
+
+        # CHECK: Hello World!
+        execute(
+            "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++."
+        )

>From ffd4421b979830eb6b0c969d9ce91fa00ec8eec1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 27 Feb 2026 20:51:21 +0800
Subject: [PATCH 4/4] [MLIR][Python] Drop unused Ptr2Type and explicit ConversionConfig from the bf test

---
 mlir/test/python/integration/dialects/bf.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index 6eb55c3249c46..2dcf4f7a15363 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -18,10 +18,6 @@ class PtrType(BfDialect.Type, name="ptr"):
     pass
 
 
-class Ptr2Type(BfDialect.Type, name="ptr2"):
-    pass
-
-
 class NextOp(BfDialect.Operation, name="next"):
     in_: Operand[PtrType]
     out: Result[PtrType[()]]
@@ -185,10 +181,7 @@ def convert_input(op, adaptor, converter, rewriter):
     target = ConversionTarget()
     target.add_illegal_dialect(BfDialect)
 
-    config = ConversionConfig()
-    config.build_materializations = False
-
-    apply_partial_conversion(op, target, patterns.freeze(), config)
+    apply_partial_conversion(op, target, patterns.freeze())
 
     with InsertionPoint(op.opview.body):
         func.FuncOp("putchar", FunctionType.get([i32], [i32]), visibility="private")



More information about the Mlir-commits mailing list