[Mlir-commits] [mlir] [MLIR][Python] Add support of `convert_region_types` and the bf integration test (PR #183664)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 27 04:51:40 PST 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/183664
>From d95cc8dc9491efa5e253ddbc1e764318d5854467 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 26 Feb 2026 22:07:07 +0800
Subject: [PATCH 1/4] [MLIR][Python] Support op adaptor for Python-defined
operations
---
mlir/python/mlir/dialects/ext.py | 50 +++++++++++++++++++++++++++++++-
mlir/test/python/dialects/ext.py | 8 +++++
2 files changed, 57 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 39aacf32dabb9..d88e25cced8f6 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -41,7 +41,17 @@
Region = ir.Region
register_dialect = _cext.register_dialect
-register_operation = _cext.register_operation
+
+
+def register_operation(dialect_cls: type) -> Callable[[type], type]:
+ register = _cext.register_operation(dialect_cls)
+
+ def decorator(op_cls: type) -> type:
+ register(op_cls)
+ _cext.register_op_adaptor(op_cls)(op_cls.Adaptor)
+ return op_cls
+
+ return decorator
def construct_instance(origin, args):
@@ -307,6 +317,13 @@ def __init_subclass__(
cls._generate_result_properties(results)
cls._generate_region_properties(regions)
+ cls.Adaptor = type(
+ "Adaptor",
+ (OperationAdator,),
+ dict(),
+ operation=cls,
+ )
+
dialect_obj.operations.append(cls)
@staticmethod
@@ -507,6 +524,37 @@ def _emit_operation(cls) -> None:
)
+class OperationAdator(ir.OpAdaptor):
+ @classmethod
+ def __init_subclass__(cls, *, operation: type):
+ cls.OPERATION_NAME = operation.OPERATION_NAME
+ cls._operation_cls = operation
+
+ operands, attrs, results, regions = partition_fields(operation._fields)
+
+ for attr in attrs:
+ setattr(
+ cls,
+ attr.name,
+ property(lambda self, name=attr.name: self.attributes[name]),
+ )
+
+ for i, operand in enumerate(operands):
+ if operation._ODS_OPERAND_SEGMENTS:
+
+ def getter(self, i=i, operand=operand):
+ operand_range = segmented_accessor(
+ self.operands,
+ self.attributes["operandSegmentSizes"],
+ i,
+ )
+ return normalize_value_range(operand_range, operand.variadicity)
+
+ setattr(cls, operand.name, property(getter))
+ else:
+ setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+
@dataclass
class ParamDef:
name: str
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f9252bad37a39..5b3f9d8416517 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -91,6 +91,14 @@ class AddOp(Operation, dialect=MyInt, name="add"):
# CHECK: (self, /, value, *, loc=None, ip=None)
print(ConstantOp.__init__.__signature__)
+ # CHECK: True
+ print(issubclass(AddOp.Adaptor, OpAdaptor))
+ adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
+ # CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
+ print(adaptor1.lhs)
+ # CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
+ print(adaptor1.rhs)
+
# CHECK: TEST: testExtDialect
@run
>From 5a5b254f50eb3e7756e057297eacc5b9855f72f8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 26 Feb 2026 22:08:43 +0800
Subject: [PATCH 2/4] append
---
mlir/test/python/dialects/ext.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 5b3f9d8416517..2921615e75d54 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -94,6 +94,8 @@ class AddOp(Operation, dialect=MyInt, name="add"):
# CHECK: True
print(issubclass(AddOp.Adaptor, OpAdaptor))
adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
+ # CHECK: myint.add
+ print(adaptor1.OPERATION_NAME)
# CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
print(adaptor1.lhs)
# CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
>From 4932c88a958ce8b2d6d5b34b8b1d04104440abfe Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 26 Feb 2026 23:31:40 +0800
Subject: [PATCH 3/4] [MLIR][Python] Add support of convert_region_types and
the bf integration test
---
mlir/include/mlir-c/Rewrite.h | 6 +
mlir/lib/Bindings/Python/Rewrite.cpp | 15 +-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 7 +
mlir/test/python/integration/dialects/bf.py | 252 ++++++++++++++++++++
4 files changed, 277 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/python/integration/dialects/bf.py
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index b4f93fd5a9b78..cc1be8b91a481 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -490,6 +490,12 @@ MLIR_CAPI_EXPORTED MlirPatternRewriter
mlirConversionPatternRewriterAsPatternRewriter(
MlirConversionPatternRewriter rewriter);
+/// Apply a signature conversion to each block in the given region.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirConversionPatternRewriterConvertRegionTypes(
+ MlirConversionPatternRewriter rewriter, MlirRegion region,
+ MlirTypeConverter typeConverter);
+
//===----------------------------------------------------------------------===//
/// ConversionTarget API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e370552c00a9a..22eaccce75883 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,11 +35,14 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
-class PyConversionPatternRewriter : PyPatternRewriter {
+class PyConversionPatternRewriter : public PyPatternRewriter {
public:
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
: PyPatternRewriter(
- mlirConversionPatternRewriterAsPatternRewriter(rewriter)) {}
+ mlirConversionPatternRewriterAsPatternRewriter(rewriter)),
+ rewriter(rewriter) {}
+
+ MlirConversionPatternRewriter rewriter;
};
class PyConversionTarget {
@@ -568,7 +571,13 @@ void populateRewriteSubmodule(nb::module_ &m) {
"Freeze the pattern set into a frozen one.");
nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
- m, "ConversionPatternRewriter");
+ m, "ConversionPatternRewriter")
+ .def("convert_region_types",
+ [](PyConversionPatternRewriter &self, PyRegion &region,
+ PyTypeConverter &typeConverter) {
+ mlirConversionPatternRewriterConvertRegionTypes(
+ self.rewriter, region.get(), typeConverter.get());
+ });
nb::class_<PyConversionTarget>(m, "ConversionTarget")
.def(
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 5900f08ae1730..1a3ab2f66382a 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -527,6 +527,13 @@ MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
return wrap(static_cast<mlir::PatternRewriter *>(unwrap(rewriter)));
}
+MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes(
+ MlirConversionPatternRewriter rewriter, MlirRegion region,
+ MlirTypeConverter typeConverter) {
+ return wrap(unwrap(rewriter)->convertRegionTypes(unwrap(region),
+ *unwrap(typeConverter)));
+}
+
//===----------------------------------------------------------------------===//
/// ConversionTarget API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
new file mode 100644
index 0000000000000..6eb55c3249c46
--- /dev/null
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -0,0 +1,252 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+# REQUIRES: host-supports-jit
+
+from mlir.ir import *
+from mlir.dialects.ext import *
+from mlir.rewrite import *
+from mlir.passmanager import *
+from mlir.execution_engine import *
+from mlir.dialects import llvm, scf, func
+from functools import partial
+
+
+class BfDialect(Dialect, name="bf"):
+ pass
+
+
+class PtrType(BfDialect.Type, name="ptr"):
+ pass
+
+
+class Ptr2Type(BfDialect.Type, name="ptr2"):
+ pass
+
+
+class NextOp(BfDialect.Operation, name="next"):
+ in_: Operand[PtrType]
+ out: Result[PtrType[()]]
+
+
+class PrevOp(BfDialect.Operation, name="prev"):
+ in_: Operand[PtrType]
+ out: Result[PtrType[()]]
+
+
+class IncOp(BfDialect.Operation, name="inc"):
+ in_: Operand[PtrType]
+
+
+class DecOp(BfDialect.Operation, name="dec"):
+ in_: Operand[PtrType]
+
+
+class InputOp(BfDialect.Operation, name="input"):
+ in_: Operand[PtrType]
+
+
+class OutputOp(BfDialect.Operation, name="output"):
+ in_: Operand[PtrType]
+
+
+class WhileOp(BfDialect.Operation, name="while"):
+ in_: Operand[PtrType]
+ out: Result[PtrType[()]]
+ body: Region
+
+
+class YieldOp(BfDialect.Operation, name="yield", traits=[IsTerminatorTrait]):
+ in_: Operand[PtrType]
+
+
+class MainOp(BfDialect.Operation, name="main"):
+ body: Region
+
+
+def parse(code: str):
+ module = Module.create()
+
+ with InsertionPoint(module.body):
+ main = MainOp()
+ main.body.blocks.append()
+ current_val = main.body.blocks[0].add_argument(
+ PtrType.get(), Location.unknown()
+ )
+
+ ip = InsertionPoint(main.body.blocks[0])
+ for c in code:
+ with ip:
+ if c == ">":
+ current_val = NextOp(current_val).out
+ elif c == "<":
+ current_val = PrevOp(current_val).out
+ elif c == "+":
+ IncOp(current_val)
+ elif c == "-":
+ DecOp(current_val)
+ elif c == ".":
+ OutputOp(current_val)
+ elif c == ",":
+ InputOp(current_val)
+ elif c == "[":
+ loop = WhileOp(current_val)
+ loop.body.blocks.append()
+ current_val = loop.body.blocks[0].add_argument(
+ PtrType.get(), Location.unknown()
+ )
+ ip = InsertionPoint(loop.body.blocks[0])
+ elif c == "]":
+ YieldOp(current_val)
+ current_val = ip.block.owner.opview.out
+ ip = InsertionPoint.after(current_val.owner)
+
+ with ip:
+ YieldOp(current_val)
+
+ return module
+
+
+def convert_bf_to_llvm(op, pass_):
+ patterns = RewritePatternSet()
+ ptr = llvm.PointerType.get()
+ i8 = IntegerType.get_signless(8)
+ i32 = IntegerType.get_signless(32)
+
+ type_converter = TypeConverter()
+
+ def convert_ptr(t):
+ return ptr if isinstance(t, PtrType) else None
+
+ type_converter.add_conversion(convert_ptr)
+
+ def convert_next(op, adaptor, converter, rewriter, offset=1):
+ with rewriter.ip:
+ gep = llvm.GEPOp(ptr, adaptor.in_, [], [offset], i8, [])
+ rewriter.replace_op(op, gep)
+
+ def convert_inc(op, adaptor, converter, rewriter, cst=1):
+ with rewriter.ip:
+ load = llvm.load(i8, adaptor.in_)
+ one = llvm.mlir_constant(IntegerAttr.get(i8, cst))
+ added = llvm.add(load, one, [])
+ store = llvm.StoreOp(added, adaptor.in_)
+ rewriter.replace_op(op, store)
+
+ def convert_main(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ fn = func.FuncOp("bf_main", FunctionType.get([ptr], [ptr]))
+ op.body.blocks[0].append_to(fn.body)
+ rewriter.convert_region_types(fn.body, converter)
+ rewriter.replace_op(op, fn)
+
+ def convert_yield(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ if isinstance(op.parent.opview, WhileOp | scf.WhileOp):
+ yield_ = scf.YieldOp([adaptor.in_])
+ else:
+ yield_ = func.ReturnOp([adaptor.in_])
+ rewriter.replace_op(op, yield_)
+
+ def convert_while(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ loop = scf.WhileOp([ptr], [adaptor.in_])
+ loop.before.blocks.append()
+ arg = loop.before.blocks[0].add_argument(ptr, Location.unknown())
+ with InsertionPoint(loop.before.blocks[0]):
+ c = llvm.load(i8, arg)
+ zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+ cond = llvm.icmp(llvm.ICmpPredicate.ne, c, zero)
+ scf.ConditionOp(cond, [arg])
+ op.body.blocks[0].append_to(loop.after)
+ rewriter.convert_region_types(loop.after, converter)
+ rewriter.replace_op(op, loop)
+
+ def convert_output(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ val = llvm.load(i8, adaptor.in_)
+ call = func.CallOp([], "bf_output", [val])
+ rewriter.replace_op(op, call)
+
+ def convert_input(op, adaptor, converter, rewriter):
+ with rewriter.ip:
+ call = func.call([i8], "bf_input", [])
+ store = llvm.StoreOp(call, adaptor.in_)
+ rewriter.replace_op(op, store)
+
+ patterns.add_conversion(NextOp, convert_next, type_converter)
+ patterns.add_conversion(PrevOp, partial(convert_next, offset=-1), type_converter)
+ patterns.add_conversion(IncOp, convert_inc, type_converter)
+ patterns.add_conversion(DecOp, partial(convert_inc, cst=-1), type_converter)
+ patterns.add_conversion(MainOp, convert_main, type_converter)
+ patterns.add_conversion(YieldOp, convert_yield, type_converter)
+ patterns.add_conversion(WhileOp, convert_while, type_converter)
+ patterns.add_conversion(OutputOp, convert_output, type_converter)
+ patterns.add_conversion(InputOp, convert_input, type_converter)
+
+ target = ConversionTarget()
+ target.add_illegal_dialect(BfDialect)
+
+ config = ConversionConfig()
+ config.build_materializations = False
+
+ apply_partial_conversion(op, target, patterns.freeze(), config)
+
+ with InsertionPoint(op.opview.body):
+ func.FuncOp("putchar", FunctionType.get([i32], [i32]), visibility="private")
+ func.FuncOp("getchar", FunctionType.get([], [i32]), visibility="private")
+
+ output = func.FuncOp("bf_output", FunctionType.get([i8], []))
+ output.body.blocks.append()
+ arg = output.body.blocks[0].add_argument(i8, Location.unknown())
+ with InsertionPoint(output.body.blocks[0]):
+ sext = llvm.sext(i32, arg)
+ func.call([i32], "putchar", [sext])
+ func.ReturnOp([])
+
+ input = func.FuncOp("bf_input", FunctionType.get([], [i8]))
+ input.body.blocks.append()
+ with InsertionPoint(input.body.blocks[0]):
+ call = func.call([i32], "getchar", [])
+ trunc = llvm.trunc(i8, call, [])
+ func.ReturnOp([trunc])
+
+ init = func.FuncOp("bf_init", FunctionType.get([], []))
+ init.attributes["llvm.emit_c_interface"] = UnitAttr.get()
+ init.body.blocks.append()
+ with InsertionPoint(init.body.blocks[0]):
+ c1024 = llvm.mlir_constant(IntegerAttr.get(i32, 1024))
+ zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
+ p = llvm.alloca(ptr, c1024, i8)
+ llvm.intr_memset(p, zero, c1024, False)
+ func.call([ptr], "bf_main", [p])
+ func.ReturnOp([])
+
+
+def execute(code):
+ module = parse(code)
+ assert module.operation.verify()
+
+ pm = PassManager()
+ pm.add(convert_bf_to_llvm)
+ pm.add("convert-scf-to-cf, convert-to-llvm")
+
+ pm.run(module.operation)
+
+ ee = ExecutionEngine(module)
+ ee.lookup("bf_init")(0)
+
+
+def run(f):
+ print("TEST:", f.__name__)
+ f()
+
+
+# CHECK: TEST: test_bf
+@run
+def test_bf():
+ with Context(), Location.unknown():
+ BfDialect.load()
+
+ # CHECK: Hello World!
+ execute(
+ "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++."
+ )
>From ffd4421b979830eb6b0c969d9ce91fa00ec8eec1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 27 Feb 2026 20:51:21 +0800
Subject: [PATCH 4/4] refine
---
mlir/test/python/integration/dialects/bf.py | 9 +--------
1 file changed, 1 insertion(+), 8 deletions(-)
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index 6eb55c3249c46..2dcf4f7a15363 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -18,10 +18,6 @@ class PtrType(BfDialect.Type, name="ptr"):
pass
-class Ptr2Type(BfDialect.Type, name="ptr2"):
- pass
-
-
class NextOp(BfDialect.Operation, name="next"):
in_: Operand[PtrType]
out: Result[PtrType[()]]
@@ -185,10 +181,7 @@ def convert_input(op, adaptor, converter, rewriter):
target = ConversionTarget()
target.add_illegal_dialect(BfDialect)
- config = ConversionConfig()
- config.build_materializations = False
-
- apply_partial_conversion(op, target, patterns.freeze(), config)
+ apply_partial_conversion(op, target, patterns.freeze())
with InsertionPoint(op.opview.body):
func.FuncOp("putchar", FunctionType.get([i32], [i32]), visibility="private")
More information about the Mlir-commits
mailing list