[Mlir-commits] [mlir] [MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (PR #171544)
Rolf Morel
llvmlistbot at llvm.org
Sat Dec 13 15:44:09 PST 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/171544
>From afda1db974918ab18c424aa02ea45a060600bec0 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 9 Dec 2025 16:34:18 -0800
Subject: [PATCH 1/2] [MLIR][Transform][Python] Wrapper for transform.foreach
Adds a friendlier wrapper, and makes OpResult.owner return the relevant
OpView instead of an Operation (like OpResultList etc.).
---
mlir/lib/Bindings/Python/IRCore.cpp | 4 +-
mlir/python/mlir/dialects/memref.py | 6 ++-
.../mlir/dialects/transform/__init__.py | 50 ++++++++++++++++++
mlir/test/python/dialects/transform.py | 52 +++++++++++++++++++
4 files changed, 108 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2e0c2b895216f..eed12fe4380bb 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
- [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+ [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
- return self.getParentOperation().getObject();
+ return self.getParentOperation()->createOpView();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index bc9a3a52728ad..91185b37a5b5f 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,8 +14,10 @@
def _is_constant_int_like(i):
return (
isinstance(i, Value)
- and isinstance(i.owner, Operation)
- and isinstance(i.owner.opview, ConstantOp)
+ and (
+ (isinstance(i.owner, Operation) and isinstance(i.owner.opview, ConstantOp))
+ or isinstance(i.owner, ConstantOp)
+ )
and _is_integer_like_type(i.type)
)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b3dd79c7dbd79..fbe4078782997 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -310,6 +310,8 @@ def __init__(
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
+ loc=loc,
+ ip=ip,
)
self.regions[0].blocks.append(*input_types)
@@ -468,6 +470,54 @@ def apply_registered_pass(
).result
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForeachOp(ForeachOp):
+    def __init__(
+        self,
+        results: Sequence[Type],
+        targets: Sequence[Union[Operation, Value, OpView]],
+        *,
+        with_zip_shortest: Optional[bool] = False,
+        loc=None,
+        ip=None,
+    ):
+        targets = [_get_op_result_or_value(target) for target in targets]  # accept Operation/OpView/Value handles
+        super().__init__(
+            results_=results,
+            targets=targets,
+            with_zip_shortest=with_zip_shortest,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*[target.type for target in targets])  # body block: one argument per target, same type
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTargets(self) -> BlockArgumentList:
+        return self.body.arguments
+
+
+def foreach(
+    results: Sequence[Type],
+    targets: Sequence[Union[Operation, Value, OpView]],
+    *,
+    with_zip_shortest: Optional[bool] = False,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, ForeachOp]:  # 1 result / >1 results / no results
+    op = ForeachOp(
+        results=results,
+        targets=targets,
+        with_zip_shortest=with_zip_shortest,
+        loc=loc,
+        ip=ip,
+    )  # keep the op itself: it is what gets returned when there are no results
+    return op.results if len(op.results) > 1 else (op.results[0] if len(op.results) == 1 else op)
+
+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index f58442d04fc66..dfcc890b83ffc 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()
+
+
+# CHECK-LABEL: TEST: testForeachOp
+@run
+def testForeachOp(module: Module):
+    # CHECK: transform.sequence
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [transform.AnyOpType.get()],
+        transform.AnyOpType.get(),
+    )
+    with InsertionPoint(sequence.body):
+        # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
+        foreach1 = transform.ForeachOp(
+            (transform.AnyOpType.get(),), (sequence.bodyTarget,)
+        )
+        with InsertionPoint(foreach1.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_(foreach1.bodyTargets)
+
+        a_val = transform.get_operand(
+            transform.AnyValueType.get(), foreach1.result, [0]
+        )
+        a_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("a_param")
+        )
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
+        foreach2 = transform.foreach(
+            (transform.AnyValueType.get(), transform.AnyParamType.get()),
+            (sequence.bodyTarget, a_val, a_param),
+        )
+        with InsertionPoint(foreach2.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
+            transform.yield_(foreach2.owner.bodyTargets[1:3])
+
+        another_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("another_param")
+        )
+        params = transform.merge_handles([a_param, another_param])
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
+        foreach3 = transform.foreach(
+            (transform.AnyOpType.get(),),
+            (foreach1.result, foreach2[1], params),
+            with_zip_shortest=True,
+        )
+        with InsertionPoint(foreach3.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_((foreach3.owner.bodyTargets[0],))
+
+        transform.yield_((foreach3,))
>From ef69bcc47f2ff2dc67fc21187e8419f0a90cdcf2 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 13 Dec 2025 15:42:45 -0800
Subject: [PATCH 2/2] Per feedback, also update Value.owner to return an OpView
---
mlir/lib/Bindings/Python/IRCore.cpp | 4 ++--
mlir/python/mlir/dialects/memref.py | 5 +----
2 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index eed12fe4380bb..da33945f42913 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4638,7 +4638,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
- [](PyValue &self) -> nb::object {
+ [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"expected the owner of the value in Python to match "
"that in "
"the IR");
- return self.getParentOperation().getObject();
+ return self.getParentOperation()->createOpView();
}
if (mlirValueIsABlockArgument(v)) {
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 91185b37a5b5f..c80a1b1a89358 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,10 +14,7 @@
def _is_constant_int_like(i):
return (
isinstance(i, Value)
- and (
- (isinstance(i.owner, Operation) and isinstance(i.owner.opview, ConstantOp))
- or isinstance(i.owner, ConstantOp)
- )
+ and isinstance(i.owner, ConstantOp)
and _is_integer_like_type(i.type)
)
More information about the Mlir-commits
mailing list