[Mlir-commits] [mlir] [MLIR][Python] Add support of `convert_region_types` and the bf integration test (PR #183664)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 26 18:06:58 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
This PR adds the `convert_region_types` API to `ConversionPatternRewriter` and introduces a new integration test, `bf.py`, which demonstrates how to combine a Python-defined dialect, the dialect conversion API, the pass manager, and the execution engine to build a pure-Python JIT compilation pipeline.
---
Full diff: https://github.com/llvm/llvm-project/pull/183664.diff
4 Files Affected:
- (modified) mlir/include/mlir-c/Rewrite.h (+6)
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+12-3)
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+7)
- (added) mlir/test/python/integration/dialects/bf.py (+252)
``````````diff
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5e952edad23cb..6947158f624c0 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -490,6 +490,12 @@ MLIR_CAPI_EXPORTED MlirPatternRewriter
mlirConversionPatternRewriterAsPatternRewriter(
MlirConversionPatternRewriter rewriter);
+/// Apply a signature conversion to each block in the given region.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirConversionPatternRewriterConvertRegionTypes(
+ MlirConversionPatternRewriter rewriter, MlirRegion region,
+ MlirTypeConverter typeConverter);
+
//===----------------------------------------------------------------------===//
/// ConversionTarget API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 6c414e1a4c023..181df847f36fe 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,11 +35,14 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
-class PyConversionPatternRewriter : PyPatternRewriter {
+class PyConversionPatternRewriter : public PyPatternRewriter {
public:
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
: PyPatternRewriter(
- mlirConversionPatternRewriterAsPatternRewriter(rewriter)) {}
+ mlirConversionPatternRewriterAsPatternRewriter(rewriter)),
+ rewriter(rewriter) {}
+
+ MlirConversionPatternRewriter rewriter;
};
class PyConversionTarget {
@@ -577,7 +580,13 @@ void populateRewriteSubmodule(nb::module_ &m) {
"Freeze the pattern set into a frozen one.");
nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
- m, "ConversionPatternRewriter");
+ m, "ConversionPatternRewriter")
+ .def("convert_region_types",
+ [](PyConversionPatternRewriter &self, PyRegion &region,
+ PyTypeConverter &typeConverter) {
+ mlirConversionPatternRewriterConvertRegionTypes(
+ self.rewriter, region.get(), typeConverter.get());
+ });
nb::class_<PyConversionTarget>(m, "ConversionTarget")
.def(
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a7e43254767ad..5a6ade352d760 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -527,6 +527,13 @@ MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
return wrap(static_cast<mlir::PatternRewriter *>(unwrap(rewriter)));
}
+MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes(
+ MlirConversionPatternRewriter rewriter, MlirRegion region,
+ MlirTypeConverter typeConverter) {
+ return wrap(unwrap(rewriter)->convertRegionTypes(unwrap(region),
+ *unwrap(typeConverter)));
+}
+
//===----------------------------------------------------------------------===//
/// ConversionTarget API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
new file mode 100644
index 0000000000000..6eb55c3249c46
--- /dev/null
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -0,0 +1,252 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+# REQUIRES: host-supports-jit
+
+from mlir.ir import *
+from mlir.dialects.ext import *
+from mlir.rewrite import *
+from mlir.passmanager import *
+from mlir.execution_engine import *
+from mlir.dialects import llvm, scf, func
+from functools import partial
+
+
+class BfDialect(Dialect, name="bf"):
+ pass
+
+
+class PtrType(BfDialect.Type, name="ptr"):
+ pass
+
+
+class Ptr2Type(BfDialect.Type, name="ptr2"):
+ pass
+
+
+class NextOp(BfDialect.Operation, name="next"):
+ in_: Operand[PtrType]
+ out: Result[PtrType[()]]
+
+
+class PrevOp(BfDialect.Operation, name="prev"):
+ in_: Operand[PtrType]
+ out: Result[PtrType[()]]
+
+
+class IncOp(BfDialect.Operation, name="inc"):
+ in_: Operand[PtrType]
+
+
+class DecOp(BfDialect.Operation, name="dec"):
+ in_: Operand[PtrType]
+
+
+class InputOp(BfDialect.Operation, name="input"):
+ in_: Operand[PtrType]
+
+
+class OutputOp(BfDialect.Operation, name="output"):
+ in_: Operand[PtrType]
+
+
+class WhileOp(BfDialect.Operation, name="while"):
+ in_: Operand[PtrType]
+ out: Result[PtrType[()]]
+ body: Region
+
+
+class YieldOp(BfDialect.Operation, name="yield", traits=[IsTerminatorTrait]):
+ in_: Operand[PtrType]
+
+
+class MainOp(BfDialect.Operation, name="main"):
+ body: Region
+
+
+def parse(code: str):
+ module = Module.create()
+
+ with InsertionPoint(module.body):
+ main = MainOp()
+ main.body.blocks.append()
+ current_val = main.body.blocks[0].add_argument(
+ PtrType.get(), Location.unknown()
+ )
+
+ ip = InsertionPoint(main.body.blocks[0])
+ for c in code:
+ with ip:
+ if c == ">":
+ current_val = NextOp(current_val).out
+ elif c == "<":
+ current_val = PrevOp(current_val).out
+ elif c == "+":
+ IncOp(current_val)
+ elif c == "-":
+ DecOp(current_val)
+ elif c == ".":
+ OutputOp(current_val)
+ elif c == ",":
+ InputOp(current_val)
+ elif c == "[":
+ loop = WhileOp(current_val)
+ loop.body.blocks.append()
+ current_val = loop.body.blocks[0].add_argument(
+ PtrType.get(), Location.unknown()
+ )
+ ip = InsertionPoint(loop.body.blocks[0])
+ elif c == "]":
+ YieldOp(current_val)
+ current_val = ip.block.owner.opview.out
+ ip = InsertionPoint.after(current_val.owner)
+
+ with ip:
+ YieldOp(current_val)
+
+ return module
+
+
+def convert_bf_to_llvm(op, pass_):
+ patterns = RewritePatternSet()
+ ptr = llvm.PointerType.get()
+ i8 = IntegerType.get_signless(8)
+ i32 = IntegerType.get_signless(32)
+
+ type_converter = TypeConverter()
+
+ def convert_ptr(t):
+ return ptr if isinstance(t, PtrType) else None
+
+ type_converter.add_conversion(convert_ptr)
+
+ def convert_next(op, adaptor, converter, rewriter, offset=1):
+ with rewriter.ip:
+ gep = llvm.GEPOp(ptr, adaptor.in_, [], [offset], i8, [])
+ rewriter.replace_op(op, gep)
+
+ def convert_inc(op, adaptor, converter, rewriter, cst=1):
+ with rewriter.ip:
+ load = llvm.load(i8, adaptor.in_)
+ one = llvm.mlir_constant(IntegerAttr.get(i8, cst))
+ added = llvm.add(load, one, [])
+ store = llvm.StoreOp(added, adaptor.in_)
+ rewriter.replace_op(op, store)
+
+ def convert_main(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ fn = func.FuncOp("bf_main", FunctionType.get([ptr], [ptr]))
+ op.body.blocks[0].append_to(fn.body)
+ rewriter.convert_region_types(fn.body, converter)
+ rewriter.replace_op(op, fn)
+
+ def convert_yield(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ if isinstance(op.parent.opview, WhileOp | scf.WhileOp):
+ yield_ = scf.YieldOp([adaptor.in_])
+ else:
+ yield_ = func.ReturnOp([adaptor.in_])
+ rewriter.replace_op(op, yield_)
+
+ def convert_while(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ loop = scf.WhileOp([ptr], [adaptor.in_])
+ loop.before.blocks.append()
+ arg = loop.before.blocks[0].add_argument(ptr, Location.unknown())
+ with InsertionPoint(loop.before.blocks[0]):
+ c = llvm.load(i8, arg)
+ zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+ cond = llvm.icmp(llvm.ICmpPredicate.ne, c, zero)
+ scf.ConditionOp(cond, [arg])
+ op.body.blocks[0].append_to(loop.after)
+ rewriter.convert_region_types(loop.after, converter)
+ rewriter.replace_op(op, loop)
+
+ def convert_output(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ val = llvm.load(i8, adaptor.in_)
+ call = func.CallOp([], "bf_output", [val])
+ rewriter.replace_op(op, call)
+
+ def convert_input(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ call = func.call([i8], "bf_input", [])
+ store = llvm.StoreOp(call, adaptor.in_)
+ rewriter.replace_op(op, store)
+
+ patterns.add_conversion(NextOp, convert_next, type_converter)
+ patterns.add_conversion(PrevOp, partial(convert_next, offset=-1), type_converter)
+ patterns.add_conversion(IncOp, convert_inc, type_converter)
+ patterns.add_conversion(DecOp, partial(convert_inc, cst=-1), type_converter)
+ patterns.add_conversion(MainOp, convert_main, type_converter)
+ patterns.add_conversion(YieldOp, convert_yield, type_converter)
+ patterns.add_conversion(WhileOp, convert_while, type_converter)
+ patterns.add_conversion(OutputOp, convert_output, type_converter)
+ patterns.add_conversion(InputOp, convert_input, type_converter)
+
+ target = ConversionTarget()
+ target.add_illegal_dialect(BfDialect)
+
+ config = ConversionConfig()
+ config.build_materializations = False
+
+ apply_partial_conversion(op, target, patterns.freeze(), config)
+
+ with InsertionPoint(op.opview.body):
+ func.FuncOp("putchar", FunctionType.get([i32], [i32]), visibility="private")
+ func.FuncOp("getchar", FunctionType.get([], [i32]), visibility="private")
+
+ output = func.FuncOp("bf_output", FunctionType.get([i8], []))
+ output.body.blocks.append()
+ arg = output.body.blocks[0].add_argument(i8, Location.unknown())
+ with InsertionPoint(output.body.blocks[0]):
+ sext = llvm.sext(i32, arg)
+ func.call([i32], "putchar", [sext])
+ func.ReturnOp([])
+
+ input = func.FuncOp("bf_input", FunctionType.get([], [i8]))
+ input.body.blocks.append()
+ with InsertionPoint(input.body.blocks[0]):
+ call = func.call([i32], "getchar", [])
+ trunc = llvm.trunc(i8, call, [])
+ func.ReturnOp([trunc])
+
+ init = func.FuncOp("bf_init", FunctionType.get([], []))
+ init.attributes["llvm.emit_c_interface"] = UnitAttr.get()
+ init.body.blocks.append()
+ with InsertionPoint(init.body.blocks[0]):
+ c1024 = llvm.mlir_constant(IntegerAttr.get(i32, 1024))
+ zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+ p = llvm.alloca(ptr, c1024, i8)
+ llvm.intr_memset(p, zero, c1024, False)
+ func.call([ptr], "bf_main", [p])
+ func.ReturnOp([])
+
+
+def execute(code):
+ module = parse(code)
+ assert module.operation.verify()
+
+ pm = PassManager()
+ pm.add(convert_bf_to_llvm)
+ pm.add("convert-scf-to-cf, convert-to-llvm")
+
+ pm.run(module.operation)
+
+ ee = ExecutionEngine(module)
+ ee.lookup("bf_init")(0)
+
+
+def run(f):
+ print("TEST:", f.__name__)
+ f()
+
+
+# CHECK: TEST: test_bf
+@run
+def test_bf():
+ with Context(), Location.unknown():
+ BfDialect.load()
+
+ # CHECK: Hello World!
+ execute(
+ "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++."
+ )
``````````
</details>
https://github.com/llvm/llvm-project/pull/183664
More information about the Mlir-commits
mailing list