[Mlir-commits] [mlir] [MLIR][Python] Expose the insertion point of pattern rewriter (PR #161001)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 2 23:49:09 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/161001
>From 28d65b8d5e0a059f790ff2a56423ab9e813c5e72 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 18:31:40 +0800
Subject: [PATCH 1/4] [MLIR][Python] Expose the insertion point of pattern
rewriter
---
mlir/include/mlir-c/Rewrite.h | 11 +++
mlir/lib/Bindings/Python/Rewrite.cpp | 16 ++++-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 16 +++++
mlir/test/python/integration/dialects/pdl.py | 76 +++++++++++++++++++-
4 files changed, 116 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 77be1f480eacf..b0f60901c5301 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -101,6 +101,9 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
MLIR_CAPI_EXPORTED MlirBlock
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
+
//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
@@ -310,6 +313,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Casts the PatternRewriter to its RewriterBase base class. The returned
+/// handle refers to the same underlying rewriter (no new object is created),
+/// so it must not be used after `rewriter` itself is gone.
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 20392b9002706..b520d8d3f1ecc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -143,7 +143,21 @@ class PyFrozenRewritePatternSet {
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
- nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
+ nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
+ .def("ip", [](MlirPatternRewriter rewriter) {
+ MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
+ MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+ MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+ MlirOperation owner = mlirBlockGetParentOperation(block);
+ auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
+ ->getRef();
+ if (mlirOperationIsNull(op)) {
+ auto parent = PyOperation::forOperation(ctx, owner);
+ return PyInsertionPoint(PyBlock(parent, block));
+ }
+
+ return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
+ });
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308cadf83..b149d35f0d88b 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -70,6 +70,18 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
return wrap(unwrap(rewriter)->getBlock());
}
+/// Returns the operation the rewriter would insert before, or a null
+/// MlirOperation when the insertion point is unset or at the end of its block.
+MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
+  mlir::RewriterBase *base = unwrap(rewriter);
+  mlir::Block *block = base->getInsertionBlock();
+  // After clearInsertionPoint() there is no block; comparing against
+  // block->end() would dereference null.
+  if (!block)
+    return {nullptr};
+  mlir::Block::iterator it = base->getInsertionPoint();
+  // Inserting at the end of the block: no operation follows the insertion
+  // point.
+  if (it == block->end())
+    return {nullptr};
+
+  return wrap(std::addressof(*it));
+}
+
//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
@@ -316,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
return {rewriter};
}
+/// Upcasts a PatternRewriter handle to a RewriterBase handle. Both handles
+/// refer to the same rewriter object; nothing is allocated or copied.
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+  // PatternRewriter publicly derives from RewriterBase, so the pointer
+  // converts implicitly.
+  mlir::RewriterBase *asBase = unwrap(rewriter);
+  return wrap(asBase);
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index c8e6197e03842..b8c7e277f1776 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -121,8 +121,10 @@ def load_myint_dialect():
# This PDL pattern is to fold constant additions,
-# i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1.
+# including two patterns:
+# 1. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1;
+# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
m = Module.create()
i32 = IntegerType.get_signless(32)
@@ -237,3 +239,73 @@ def test_pdl_register_function_constraint(module_):
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+
+# This pattern expands a constant into an addition of two smaller constants
+# unless the constant is at most 1, e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
+# The native "expand" rewrite creates its new ops inside `rewriter.ip()`,
+# i.e. at the pattern rewriter's current insertion point.
+# NOTE(review): the pattern is meant to match only constants > 1, yet
+# `is_one` returns True for values <= 1; presumably a True return signals
+# constraint *failure* in these bindings. TODO confirm; the name is misleading.
+def get_pdl_pattern_expand():
+ m = Module.create()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(m.body):
+
+ @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
+ def pat():
+ t = pdl.TypeOp(i32)
+ cst = pdl.AttributeOp()
+ pdl.apply_native_constraint([], "is_one", [cst])
+ op0 = pdl.OperationOp(name="myint.constant", attributes={"value": cst}, types=[t])
+
+ @pdl.rewrite()
+ def rew():
+ expanded = pdl.apply_native_rewrite([pdl.OperationType.get()], "expand", [cst])
+ pdl.ReplaceOp(op0, with_op=expanded)
+
+ def is_one(rewriter, results, values):
+ cst = values[0].value
+ return cst <= 1
+
+ def expand(rewriter, results, values):
+ cst = values[0].value
+ c1 = cst // 2
+ c2 = cst - c1
+ with rewriter.ip():
+ op1 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c1)})
+ op2 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c2)})
+ res = Operation.create("myint.add", results=[i32], operands=[op1.result, op2.result])
+ results.append(res)
+
+ pdl_module = PDLModule(m)
+ pdl_module.register_constraint_function("is_one", is_one)
+ pdl_module.register_rewrite_function("expand", expand)
+ return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_expand
+# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
+# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
+# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
+# CHECK: return %8 : i32
+ at construct_and_print_in_module
+def test_pdl_register_function_expand(module_):
+ load_myint_dialect()
+
+ module_ = Module.parse(
+ """
+ func.func @f() -> i32 {
+ %0 = "myint.constant"() { value = 5 }: () -> (i32)
+ return %0 : i32
+ }
+ """
+ )
+
+ frozen = get_pdl_pattern_expand()
+ apply_patterns_and_fold_greedily(module_, frozen)
+
+ return module_
>From beab53db10c980b9a326abc1c1bb3ca73b4bbddd Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 19:16:43 +0800
Subject: [PATCH 2/4] format
---
mlir/test/python/integration/dialects/pdl.py | 27 +++++++++++++++-----
1 file changed, 21 insertions(+), 6 deletions(-)
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index b8c7e277f1776..752d213673a70 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f
+
def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
@@ -254,25 +255,39 @@ def pat():
t = pdl.TypeOp(i32)
cst = pdl.AttributeOp()
pdl.apply_native_constraint([], "is_one", [cst])
- op0 = pdl.OperationOp(name="myint.constant", attributes={"value": cst}, types=[t])
+ op0 = pdl.OperationOp(
+ name="myint.constant", attributes={"value": cst}, types=[t]
+ )
@pdl.rewrite()
def rew():
- expanded = pdl.apply_native_rewrite([pdl.OperationType.get()], "expand", [cst])
+ expanded = pdl.apply_native_rewrite(
+ [pdl.OperationType.get()], "expand", [cst]
+ )
pdl.ReplaceOp(op0, with_op=expanded)
def is_one(rewriter, results, values):
cst = values[0].value
return cst <= 1
-
+
def expand(rewriter, results, values):
cst = values[0].value
c1 = cst // 2
c2 = cst - c1
with rewriter.ip():
- op1 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c1)})
- op2 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c2)})
- res = Operation.create("myint.add", results=[i32], operands=[op1.result, op2.result])
+ op1 = Operation.create(
+ "myint.constant",
+ results=[i32],
+ attributes={"value": IntegerAttr.get(i32, c1)},
+ )
+ op2 = Operation.create(
+ "myint.constant",
+ results=[i32],
+ attributes={"value": IntegerAttr.get(i32, c2)},
+ )
+ res = Operation.create(
+ "myint.add", results=[i32], operands=[op1.result, op2.result]
+ )
results.append(res)
pdl_module = PDLModule(m)
>From 68fbb0f18f8e86d51f16c4d6d9f4936133ef6d13 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 22:20:11 +0800
Subject: [PATCH 3/4] add comment for c api
---
mlir/include/mlir-c/Rewrite.h | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index b0f60901c5301..c53470ca09960 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -101,6 +101,9 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
MLIR_CAPI_EXPORTED MlirBlock
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
+/// Returns the operation right after the current insertion point
+/// of the rewriter. A null MlirOperation will be returned
+/// if the current insertion block is empty.
MLIR_CAPI_EXPORTED MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
>From 02f17827662ae3f5fd6fc5ff498aa3196bfb97c1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 22:27:29 +0800
Subject: [PATCH 4/4] fix doc
---
mlir/include/mlir-c/Rewrite.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index c53470ca09960..5dd285ee076c4 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -103,7 +103,7 @@ mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
/// Returns the operation right after the current insertion point
/// of the rewriter. A null MlirOperation will be returned
-/// if the current insertion block is empty.
+/// if the current insertion point is at the end of the block.
MLIR_CAPI_EXPORTED MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
More information about the Mlir-commits
mailing list