[Mlir-commits] [mlir] [MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (PR #171544)

Rolf Morel llvmlistbot at llvm.org
Sat Dec 13 15:44:09 PST 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/171544

>From afda1db974918ab18c424aa02ea45a060600bec0 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 9 Dec 2025 16:34:18 -0800
Subject: [PATCH 1/2] [MLIR][Transform][Python] Wrapper for transform.foreach

Adds a friendlier wrapper for transform.foreach and makes OpResult.owner
return the relevant OpView instead of an Operation (consistent with
OpResultList etc.).
---
 mlir/lib/Bindings/Python/IRCore.cpp           |  4 +-
 mlir/python/mlir/dialects/memref.py           |  6 ++-
 .../mlir/dialects/transform/__init__.py       | 50 ++++++++++++++++++
 mlir/test/python/dialects/transform.py        | 52 +++++++++++++++++++
 4 files changed, 108 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2e0c2b895216f..eed12fe4380bb 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
   static void bindDerived(ClassTy &c) {
     c.def_prop_ro(
         "owner",
-        [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+        [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
           assert(mlirOperationEqual(self.getParentOperation()->get(),
                                     mlirOpResultGetOwner(self.get())) &&
                  "expected the owner of the value in Python to match that in "
                  "the IR");
-          return self.getParentOperation().getObject();
+          return self.getParentOperation()->createOpView();
         },
         "Returns the operation that produces this result.");
     c.def_prop_ro(
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index bc9a3a52728ad..91185b37a5b5f 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,8 +14,10 @@
 def _is_constant_int_like(i):
     return (
         isinstance(i, Value)
-        and isinstance(i.owner, Operation)
-        and isinstance(i.owner.opview, ConstantOp)
+        and (
+            (isinstance(i.owner, Operation) and isinstance(i.owner.opview, ConstantOp))
+            or isinstance(i.owner, ConstantOp)
+        )
         and _is_integer_like_type(i.type)
     )
 
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b3dd79c7dbd79..fbe4078782997 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -310,6 +310,8 @@ def __init__(
             sym_visibility=sym_visibility,
             arg_attrs=arg_attrs,
             res_attrs=res_attrs,
+            loc=loc,
+            ip=ip,
         )
         self.regions[0].blocks.append(*input_types)
 
@@ -468,6 +470,73 @@ def apply_registered_pass(
     ).result


+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForeachOp(ForeachOp):
+    """``transform.foreach`` whose body block is created on construction.
+
+    The body block has one argument per target, typed like that target,
+    and is exposed through the ``body`` and ``bodyTargets`` properties.
+    """
+
+    def __init__(
+        self,
+        results: Sequence[Type],
+        targets: Sequence[Union[Operation, Value, OpView]],
+        *,
+        with_zip_shortest: Optional[bool] = False,
+        loc=None,
+        ip=None,
+    ):
+        # Convert targets given as ops/OpViews/values into values so that their
+        # `.type` can be used for the body block's arguments below.
+        targets = [_get_op_result_or_value(target) for target in targets]
+        super().__init__(
+            results_=results,
+            targets=targets,
+            with_zip_shortest=with_zip_shortest,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*[target.type for target in targets])
+
+    @property
+    def body(self) -> Block:
+        """The first block of the op's region (the one created in __init__)."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTargets(self) -> BlockArgumentList:
+        """The body block's arguments, one per target."""
+        return self.regions[0].blocks[0].arguments
+
+
+def foreach(
+    results: Sequence[Type],
+    targets: Sequence[Union[Operation, Value, OpView]],
+    *,
+    with_zip_shortest: Optional[bool] = False,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, ForeachOp]:
+    """Create a ``transform.foreach`` op and return a handle to it.
+
+    Returns the single result if the op has exactly one, the full result
+    list if it has several, and the ``ForeachOp`` itself if it has none
+    (there is then no result whose ``.owner`` would lead back to the op).
+    """
+    op = ForeachOp(
+        results=results,
+        targets=targets,
+        with_zip_shortest=with_zip_shortest,
+        loc=loc,
+        ip=ip,
+    )
+    op_results = op.results
+    if len(op_results) > 1:
+        return op_results
+    return op_results[0] if len(op_results) == 1 else op
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index f58442d04fc66..dfcc890b83ffc 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -401,3 +401,56 @@ def testApplyRegisteredPassOp(module: Module):
             options={"exclude": (symbol_a, symbol_b)},
         )
         transform.YieldOp()
+
+
+# CHECK-LABEL: TEST: testForeachOp
+@run
+def testForeachOp(module: Module):
+    """Build transform.foreach ops both via ForeachOp and via foreach()."""
+    # CHECK: transform.sequence
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [transform.AnyOpType.get()],
+        transform.AnyOpType.get(),
+    )
+    with InsertionPoint(sequence.body):
+        # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
+        foreach1 = transform.ForeachOp(
+            (transform.AnyOpType.get(),), (sequence.bodyTarget,)
+        )
+        with InsertionPoint(foreach1.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_(foreach1.bodyTargets)
+
+        a_val = transform.get_operand(
+            transform.AnyValueType.get(), foreach1.result, [0]
+        )
+        a_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("a_param")
+        )
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
+        foreach2 = transform.foreach(
+            (transform.AnyValueType.get(), transform.AnyParamType.get()),
+            (sequence.bodyTarget, a_val, a_param),
+        )
+        with InsertionPoint(foreach2.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
+            transform.yield_(foreach2.owner.bodyTargets[1:3])
+
+        another_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("another_param")
+        )
+        params = transform.merge_handles([a_param, another_param])
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
+        foreach3 = transform.foreach(
+            (transform.AnyOpType.get(),),
+            (foreach1.result, foreach2[1], params),
+            with_zip_shortest=True,
+        )
+        with InsertionPoint(foreach3.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_((foreach3.owner.bodyTargets[0],))
+
+        transform.yield_((foreach3,))

>From ef69bcc47f2ff2dc67fc21187e8419f0a90cdcf2 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 13 Dec 2025 15:42:45 -0800
Subject: [PATCH 2/2] Per feedback, also update Value.owner to return an OpView

---
 mlir/lib/Bindings/Python/IRCore.cpp | 4 ++--
 mlir/python/mlir/dialects/memref.py | 5 +----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index eed12fe4380bb..da33945f42913 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4638,7 +4638,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kDumpDocstring)
       .def_prop_ro(
           "owner",
-          [](PyValue &self) -> nb::object {
+          [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
             MlirValue v = self.get();
             if (mlirValueIsAOpResult(v)) {
               assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                      "expected the owner of the value in Python to match "
                      "that in "
                      "the IR");
-              return self.getParentOperation().getObject();
+              return self.getParentOperation()->createOpView();
             }
 
             if (mlirValueIsABlockArgument(v)) {
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 91185b37a5b5f..c80a1b1a89358 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,10 +14,7 @@
 def _is_constant_int_like(i):
     return (
         isinstance(i, Value)
-        and (
-            (isinstance(i.owner, Operation) and isinstance(i.owner.opview, ConstantOp))
-            or isinstance(i.owner, ConstantOp)
-        )
+        and isinstance(i.owner, ConstantOp)  # op-result owners are OpViews now; no .opview needed
         and _is_integer_like_type(i.type)
     )
 



More information about the Mlir-commits mailing list