[Mlir-commits] [mlir] 4673cec - [MLIR][Python] Add support of `convert_region_types` and the bf integration test (#183664)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 28 17:55:38 PST 2026


Author: Twice
Date: 2026-03-01T09:55:33+08:00
New Revision: 4673cecc89d1a2260edcff2808ed46de7bc0435d

URL: https://github.com/llvm/llvm-project/commit/4673cecc89d1a2260edcff2808ed46de7bc0435d
DIFF: https://github.com/llvm/llvm-project/commit/4673cecc89d1a2260edcff2808ed46de7bc0435d.diff

LOG: [MLIR][Python] Add support of `convert_region_types` and the bf integration test (#183664)

This PR adds the `convert_region_types` API to
`ConversionPatternRewriter` and introduces a new integration test,
`bf.py`, which demonstrates how to combine a Python-defined dialect, the
dialect conversion API, the pass manager, and the execution engine to
build a pure-Python JIT compilation pipeline.

Added: 
    mlir/test/python/integration/dialects/bf.py

Modified: 
    mlir/include/mlir-c/Rewrite.h
    mlir/lib/Bindings/Python/Rewrite.cpp
    mlir/lib/CAPI/Transforms/Rewrite.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5e952edad23cb..6947158f624c0 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -490,6 +490,12 @@ MLIR_CAPI_EXPORTED MlirPatternRewriter
 mlirConversionPatternRewriterAsPatternRewriter(
     MlirConversionPatternRewriter rewriter);
 
+/// Apply a signature conversion to each block in the given region.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirConversionPatternRewriterConvertRegionTypes(
+    MlirConversionPatternRewriter rewriter, MlirRegion region,
+    MlirTypeConverter typeConverter);
+
 //===----------------------------------------------------------------------===//
 /// ConversionTarget API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 6c414e1a4c023..181df847f36fe 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,11 +35,14 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
       : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
 };
 
-class PyConversionPatternRewriter : PyPatternRewriter {
+class PyConversionPatternRewriter : public PyPatternRewriter {
 public:
   PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
       : PyPatternRewriter(
-            mlirConversionPatternRewriterAsPatternRewriter(rewriter)) {}
+            mlirConversionPatternRewriterAsPatternRewriter(rewriter)),
+        rewriter(rewriter) {}
+
+  MlirConversionPatternRewriter rewriter;
 };
 
 class PyConversionTarget {
@@ -577,7 +580,21 @@ void populateRewriteSubmodule(nb::module_ &m) {
            "Freeze the pattern set into a frozen one.");
 
   nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
-      m, "ConversionPatternRewriter");
+      m, "ConversionPatternRewriter")
+      .def(
+          "convert_region_types",
+          [](PyConversionPatternRewriter &self, PyRegion &region,
+             PyTypeConverter &typeConverter) {
+            // The C API reports failure through its MlirLogicalResult; raise
+            // instead of dropping it so Python callers can see the failure.
+            MlirLogicalResult status =
+                mlirConversionPatternRewriterConvertRegionTypes(
+                    self.rewriter, region.get(), typeConverter.get());
+            if (mlirLogicalResultIsFailure(status))
+              throw std::runtime_error("failed to convert region types");
+          },
+          "Apply a signature conversion to each block in the given region. "
+          "Raises RuntimeError if the conversion fails.");
 
   nb::class_<PyConversionTarget>(m, "ConversionTarget")
       .def(

diff  --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a7e43254767ad..5a6ade352d760 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -527,6 +527,13 @@ MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
   return wrap(static_cast<mlir::PatternRewriter *>(unwrap(rewriter)));
 }
 
+MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes(
+    MlirConversionPatternRewriter rewriter, MlirRegion region,
+    MlirTypeConverter typeConverter) {
+  return wrap(unwrap(rewriter)->convertRegionTypes(unwrap(region),
+                                                   *unwrap(typeConverter)));
+}
+
 //===----------------------------------------------------------------------===//
 /// ConversionTarget API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
new file mode 100644
index 0000000000000..cbb815378ddf7
--- /dev/null
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -0,0 +1,284 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+# REQUIRES: host-supports-jit
+
+from mlir.ir import *
+from mlir.dialects.ext import *
+from mlir.rewrite import *
+from mlir.passmanager import *
+from mlir.execution_engine import *
+from mlir.dialects import llvm, scf, func
+from functools import partial
+
+
+class BfDialect(Dialect, name="bf"):
+    pass
+
+
+class PtrType(BfDialect.Type, name="ptr"):
+    pass
+
+
+class NextOp(BfDialect.Operation, name="next"):
+    in_: Operand[PtrType]
+    out: Result[PtrType[()]]
+
+
+class PrevOp(BfDialect.Operation, name="prev"):
+    in_: Operand[PtrType]
+    out: Result[PtrType[()]]
+
+
+class IncOp(BfDialect.Operation, name="inc"):
+    in_: Operand[PtrType]
+
+
+class DecOp(BfDialect.Operation, name="dec"):
+    in_: Operand[PtrType]
+
+
+class InputOp(BfDialect.Operation, name="input"):
+    in_: Operand[PtrType]
+
+
+class OutputOp(BfDialect.Operation, name="output"):
+    in_: Operand[PtrType]
+
+
+class WhileOp(BfDialect.Operation, name="while"):
+    in_: Operand[PtrType]
+    out: Result[PtrType[()]]
+    body: Region
+
+
+class YieldOp(BfDialect.Operation, name="yield", traits=[IsTerminatorTrait]):
+    in_: Operand[PtrType]
+
+
+class MainOp(BfDialect.Operation, name="main"):
+    body: Region
+
+
+def parse(code: str):
+    module = Module.create()
+
+    with InsertionPoint(module.body):
+        main = MainOp()
+        main.body.blocks.append()
+        current_val = main.body.blocks[0].add_argument(
+            PtrType.get(), Location.unknown()
+        )
+
+        ip = InsertionPoint(main.body.blocks[0])
+        for c in code:
+            with ip:
+                if c == ">":
+                    current_val = NextOp(current_val).out
+                elif c == "<":
+                    current_val = PrevOp(current_val).out
+                elif c == "+":
+                    IncOp(current_val)
+                elif c == "-":
+                    DecOp(current_val)
+                elif c == ".":
+                    OutputOp(current_val)
+                elif c == ",":
+                    InputOp(current_val)
+                elif c == "[":
+                    loop = WhileOp(current_val)
+                    loop.body.blocks.append()
+                    current_val = loop.body.blocks[0].add_argument(
+                        PtrType.get(), Location.unknown()
+                    )
+                    ip = InsertionPoint(loop.body.blocks[0])
+                elif c == "]":
+                    YieldOp(current_val)
+                    current_val = ip.block.owner.opview.out
+                    ip = InsertionPoint.after(current_val.owner)
+
+        with ip:
+            YieldOp(current_val)
+
+    return module
+
+
+def convert_bf_to_llvm(op, pass_):
+    patterns = RewritePatternSet()
+    ptr = llvm.PointerType.get()
+    i8 = IntegerType.get_signless(8)
+    i32 = IntegerType.get_signless(32)
+
+    type_converter = TypeConverter()
+
+    def convert_ptr(t):
+        return ptr if isinstance(t, PtrType) else None
+
+    type_converter.add_conversion(convert_ptr)
+
+    def convert_next(op, adaptor, converter, rewriter, offset=1):
+        with rewriter.ip:
+            gep = llvm.GEPOp(ptr, adaptor.in_, [], [offset], i8, [])
+        rewriter.replace_op(op, gep)
+
+    def convert_inc(op, adaptor, converter, rewriter, cst=1):
+        with rewriter.ip:
+            load = llvm.load(i8, adaptor.in_)
+            one = llvm.mlir_constant(IntegerAttr.get(i8, cst))
+            added = llvm.add(load, one, [])
+            store = llvm.StoreOp(added, adaptor.in_)
+        rewriter.replace_op(op, store)
+
+    def convert_main(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            fn = func.FuncOp("bf_main", FunctionType.get([ptr], [ptr]))
+            op.body.blocks[0].append_to(fn.body)
+            rewriter.convert_region_types(fn.body, converter)
+        rewriter.replace_op(op, fn)
+
+    def convert_yield(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            if isinstance(op.parent.opview, WhileOp | scf.WhileOp):
+                yield_ = scf.YieldOp([adaptor.in_])
+            else:
+                yield_ = func.ReturnOp([adaptor.in_])
+        rewriter.replace_op(op, yield_)
+
+    def convert_while(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            loop = scf.WhileOp([ptr], [adaptor.in_])
+            loop.before.blocks.append()
+            arg = loop.before.blocks[0].add_argument(ptr, Location.unknown())
+            with InsertionPoint(loop.before.blocks[0]):
+                c = llvm.load(i8, arg)
+                zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+                cond = llvm.icmp(llvm.ICmpPredicate.ne, c, zero)
+                scf.ConditionOp(cond, [arg])
+            op.body.blocks[0].append_to(loop.after)
+            rewriter.convert_region_types(loop.after, converter)
+        rewriter.replace_op(op, loop)
+
+    def convert_output(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            val = llvm.load(i8, adaptor.in_)
+            call = func.CallOp([], "bf_output", [val])
+        rewriter.replace_op(op, call)
+
+    def convert_input(op, adaptor, converter, rewriter):
+        with rewriter.ip:
+            call = func.call([i8], "bf_input", [])
+            store = llvm.StoreOp(call, adaptor.in_)
+        rewriter.replace_op(op, store)
+
+    patterns.add_conversion(NextOp, convert_next, type_converter)
+    patterns.add_conversion(PrevOp, partial(convert_next, offset=-1), type_converter)
+    patterns.add_conversion(IncOp, convert_inc, type_converter)
+    patterns.add_conversion(DecOp, partial(convert_inc, cst=-1), type_converter)
+    patterns.add_conversion(MainOp, convert_main, type_converter)
+    patterns.add_conversion(YieldOp, convert_yield, type_converter)
+    patterns.add_conversion(WhileOp, convert_while, type_converter)
+    patterns.add_conversion(OutputOp, convert_output, type_converter)
+    patterns.add_conversion(InputOp, convert_input, type_converter)
+
+    target = ConversionTarget()
+    target.add_illegal_dialect(BfDialect)
+
+    apply_partial_conversion(op, target, patterns.freeze())
+
+    with InsertionPoint(op.opview.body):
+        func.FuncOp("putchar", FunctionType.get([i32], [i32]), visibility="private")
+        func.FuncOp("getchar", FunctionType.get([], [i32]), visibility="private")
+
+        output = func.FuncOp("bf_output", FunctionType.get([i8], []))
+        output.body.blocks.append()
+        arg = output.body.blocks[0].add_argument(i8, Location.unknown())
+        with InsertionPoint(output.body.blocks[0]):
+            sext = llvm.sext(i32, arg)
+            func.call([i32], "putchar", [sext])
+            func.ReturnOp([])
+
+        input = func.FuncOp("bf_input", FunctionType.get([], [i8]))
+        input.body.blocks.append()
+        with InsertionPoint(input.body.blocks[0]):
+            call = func.call([i32], "getchar", [])
+            trunc = llvm.trunc(i8, call, [])
+            func.ReturnOp([trunc])
+
+        init = func.FuncOp("bf_init", FunctionType.get([], []))
+        init.attributes["llvm.emit_c_interface"] = UnitAttr.get()
+        init.body.blocks.append()
+        with InsertionPoint(init.body.blocks[0]):
+            c1024 = llvm.mlir_constant(IntegerAttr.get(i32, 1024))
+            zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+            p = llvm.alloca(ptr, c1024, i8)
+            llvm.intr_memset(p, zero, c1024, False)
+            func.call([ptr], "bf_main", [p])
+            func.ReturnOp([])
+
+
+def execute(code):
+    module = parse(code)
+    assert module.operation.verify()
+
+    pm = PassManager()
+    pm.add(convert_bf_to_llvm)
+    pm.add("convert-scf-to-cf, convert-to-llvm")
+
+    pm.run(module.operation)
+
+    ee = ExecutionEngine(module)
+    ee.lookup("bf_init")(0)
+
+
+def run(f):
+    print("TEST:", f.__name__)
+    f()
+
+
+with Context(), Location.unknown():
+    BfDialect.load()
+
+    # CHECK: TEST: test_convert_bf_to_llvm
+    @run
+    def test_convert_bf_to_llvm():
+        module = parse("[-]")
+        assert module.operation.verify()
+
+        # CHECK: "bf.main"() ({
+        # CHECK: ^bb0(%arg0: !bf.ptr):
+        # CHECK:   %0 = "bf.while"(%arg0) ({
+        # CHECK:   ^bb0(%arg1: !bf.ptr):
+        # CHECK:     "bf.dec"(%arg1) : (!bf.ptr) -> ()
+        # CHECK:     "bf.yield"(%arg1) : (!bf.ptr) -> ()
+        # CHECK:   }) : (!bf.ptr) -> !bf.ptr
+        # CHECK:   "bf.yield"(%0) : (!bf.ptr) -> ()
+        # CHECK: }) : () -> ()
+        print(module)
+
+        pm = PassManager()
+        pm.add(convert_bf_to_llvm)
+        pm.run(module.operation)
+
+        # CHECK: func.func @bf_main(%arg0: !llvm.ptr) -> !llvm.ptr {
+        # CHECK:   %0 = scf.while (%arg1 = %arg0) : (!llvm.ptr) -> !llvm.ptr {
+        # CHECK:     %1 = llvm.load %arg1 : !llvm.ptr -> i8
+        # CHECK:     %2 = llvm.mlir.constant(0 : i8) : i8
+        # CHECK:     %3 = llvm.icmp "ne" %1, %2 : i8
+        # CHECK:     scf.condition(%3) %arg1 : !llvm.ptr
+        # CHECK:   } do {
+        # CHECK:   ^bb0(%arg1: !llvm.ptr):
+        # CHECK:     %1 = llvm.load %arg1 : !llvm.ptr -> i8
+        # CHECK:     %2 = llvm.mlir.constant(-1 : i8) : i8
+        # CHECK:     %3 = llvm.add %1, %2 : i8
+        # CHECK:     llvm.store %3, %arg1 : i8, !llvm.ptr
+        # CHECK:     scf.yield %arg1 : !llvm.ptr
+        # CHECK:   }
+        # CHECK:   return %0 : !llvm.ptr
+        # CHECK: }
+        print(module)
+
+    # CHECK: TEST: test_bf_e2e
+    @run
+    def test_bf_e2e():
+        # CHECK: Hello World!
+        execute(
+            "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++."
+        )


        


More information about the Mlir-commits mailing list