[Mlir-commits] [mlir] [MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (PR #172228)

Rolf Morel llvmlistbot at llvm.org
Sun Dec 14 13:53:12 PST 2025


https://github.com/rolfmorel created https://github.com/llvm/llvm-project/pull/172228

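Adds a friendlier Python wrapper for transform.foreach: a ForeachOp builder that also creates the body block, plus a foreach() helper function.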

To facilitate that friendliness, this makes OpResult.owner return the relevant OpView instead of an Operation. For consistency, Value.owner is also changed to return an OpView instead of an Operation. That is, all op-returning .owner accessors now return an OpView (and thereby give access to all the conveniences available on registered OpViews).

Reland of https://github.com/llvm/llvm-project/pull/171544, now including a fix for the integration test (mlir/test/python/integration/dialects/pdl.py).

>From e96870e604d05ddf26410c1c416783ea9dd4998c Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 14 Dec 2025 13:48:09 -0800
Subject: [PATCH 1/2] Fix integration test

---
 mlir/test/python/integration/dialects/pdl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index fe27dd4203a21..6a377a090fbb9 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -174,7 +174,7 @@ def add_fold(rewriter, results, values):
 
     def is_zero(value):
         op = value.owner
-        if isinstance(op, Operation):
+        if isinstance(op, OpView):  # Value.owner now yields an OpView, not an Operation
             return op.name == "myint.constant" and op.attributes["value"].value == 0
         return False
 

>From 187e28fab265164a556bc70753f1de8e467ee876 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 14 Dec 2025 13:49:37 -0800
Subject: [PATCH 2/2] Reapply "[MLIR][Transform][Python] transform.foreach
 wrapper and .owner OpViews" (#172225)

This reverts commit b9fe6532a70c58a4d73b59de88d508a651d3abc9.
---
 mlir/lib/Bindings/Python/IRCore.cpp           |  8 +--
 mlir/python/mlir/dialects/memref.py           |  3 +-
 .../mlir/dialects/transform/__init__.py       | 50 ++++++++++++++++++
 mlir/test/python/dialects/transform.py        | 52 +++++++++++++++++++
 4 files changed, 107 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b0de14719ab61..168c57955af07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
   static void bindDerived(ClassTy &c) {
     c.def_prop_ro(
         "owner",
-        [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+        [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
           assert(mlirOperationEqual(self.getParentOperation()->get(),
                                     mlirOpResultGetOwner(self.get())) &&
                  "expected the owner of the value in Python to match that in "
                  "the IR");
-          return self.getParentOperation().getObject();
+          return self.getParentOperation()->createOpView();
         },
         "Returns the operation that produces this result.");
     c.def_prop_ro(
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kDumpDocstring)
       .def_prop_ro(
           "owner",
-          [](PyValue &self) -> nb::object {
+          [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
             MlirValue v = self.get();
             if (mlirValueIsAOpResult(v)) {
               assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                      "expected the owner of the value in Python to match "
                      "that in "
                      "the IR");
-              return self.getParentOperation().getObject();
+              return self.getParentOperation()->createOpView();
             }
 
             if (mlirValueIsABlockArgument(v)) {
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index bc9a3a52728ad..c80a1b1a89358 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,8 +14,7 @@
 def _is_constant_int_like(i):
     return (
         isinstance(i, Value)
-        and isinstance(i.owner, Operation)
-        and isinstance(i.owner.opview, ConstantOp)
+        and isinstance(i.owner, ConstantOp)  # .owner is already an OpView; no .opview hop
         and _is_integer_like_type(i.type)
     )
 
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b3dd79c7dbd79..fbe4078782997 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -310,6 +310,8 @@ def __init__(
             sym_visibility=sym_visibility,
             arg_attrs=arg_attrs,
             res_attrs=res_attrs,
+            loc=loc,
+            ip=ip,
         )
         self.regions[0].blocks.append(*input_types)
 
@@ -468,6 +470,54 @@ def apply_registered_pass(
     ).result
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForeachOp(ForeachOp):  # replaces the registered ForeachOp with this subclass
+    def __init__(
+        self,
+        results: Sequence[Type],
+        targets: Sequence[Union[Operation, Value, OpView]],
+        *,
+        with_zip_shortest: Optional[bool] = False,
+        loc=None,
+        ip=None,
+    ):  # results: result types; each target becomes an operand and a block argument
+        targets = [_get_op_result_or_value(target) for target in targets]  # accept Operation/OpView/Value
+        super().__init__(
+            results_=results,
+            targets=targets,
+            with_zip_shortest=with_zip_shortest,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*[target.type for target in targets])
+
+    @property
+    def body(self) -> Block:  # the first (and only created) block of the op's region
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTargets(self) -> BlockArgumentList:  # one block argument per target
+        return self.body.arguments
+
+
+def foreach(  # builds a ForeachOp; returns its result(s), or the op if it has none
+    results: Sequence[Type],
+    targets: Sequence[Union[Operation, Value, OpView]],
+    *,
+    with_zip_shortest: Optional[bool] = False,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, ForeachOp]:
+    op = ForeachOp(
+        results=results,
+        targets=targets,
+        with_zip_shortest=with_zip_shortest,
+        loc=loc,
+        ip=ip,
+    )  # keep the op itself: a zero-result foreach must return it (was a NameError)
+    return op.results if len(op.results) > 1 else (op.result if len(op.results) == 1 else op)
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index f58442d04fc66..dfcc890b83ffc 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
             options={"exclude": (symbol_a, symbol_b)},
         )
         transform.YieldOp()
+
+
+# CHECK-LABEL: TEST: testForeachOp
+@run
+def testForeachOp(module: Module):
+    # CHECK: transform.sequence
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [transform.AnyOpType.get()],
+        transform.AnyOpType.get(),
+    )
+    with InsertionPoint(sequence.body):
+        # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
+        foreach1 = transform.ForeachOp(
+            (transform.AnyOpType.get(),), (sequence.bodyTarget,)
+        )
+        with InsertionPoint(foreach1.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_(foreach1.bodyTargets)
+
+        a_val = transform.get_operand(
+            transform.AnyValueType.get(), foreach1.result, [0]
+        )
+        a_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("a_param")
+        )
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
+        foreach2 = transform.foreach(  # two result types -> returns the results list
+            (transform.AnyValueType.get(), transform.AnyParamType.get()),
+            (sequence.bodyTarget, a_val, a_param),
+        )
+        with InsertionPoint(foreach2.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
+            transform.yield_(foreach2.owner.bodyTargets[1:3])
+
+        another_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("another_param")
+        )
+        params = transform.merge_handles([a_param, another_param])
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
+        foreach3 = transform.foreach(  # one result type -> returns that single result
+            (transform.AnyOpType.get(),),
+            (foreach1.result, foreach2[1], params),
+            with_zip_shortest=True,
+        )
+        with InsertionPoint(foreach3.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_((foreach3.owner.bodyTargets[0],))
+
+        transform.yield_((foreach3,))



More information about the Mlir-commits mailing list