[llvm-branch-commits] [mlir] db308ed - Revert "[MLIR][Transform][Python] transform.foreach wrapper and .owner OpView…"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Dec 14 13:11:17 PST 2025
Author: Mehdi Amini
Date: 2025-12-14T22:11:13+01:00
New Revision: db308edbab49611d3f1f15d3d77d43a4457007f3
URL: https://github.com/llvm/llvm-project/commit/db308edbab49611d3f1f15d3d77d43a4457007f3
DIFF: https://github.com/llvm/llvm-project/commit/db308edbab49611d3f1f15d3d77d43a4457007f3.diff
LOG: Revert "[MLIR][Transform][Python] transform.foreach wrapper and .owner OpView…"
This reverts commit 4cdec92827e6901e077e7f50a382d6acabe7aaf0.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/dialects/memref.py
mlir/python/mlir/dialects/transform/__init__.py
mlir/test/python/dialects/transform.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..b0de14719ab61 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
- [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
+ [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
- return self.getParentOperation()->createOpView();
+ return self.getParentOperation().getObject();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
- [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
+ [](PyValue &self) -> nb::object {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"expected the owner of the value in Python to match "
"that in "
"the IR");
- return self.getParentOperation()->createOpView();
+ return self.getParentOperation().getObject();
}
if (mlirValueIsABlockArgument(v)) {
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index c80a1b1a89358..bc9a3a52728ad 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,7 +14,8 @@
def _is_constant_int_like(i):
return (
isinstance(i, Value)
- and isinstance(i.owner, ConstantOp)
+ and isinstance(i.owner, Operation)
+ and isinstance(i.owner.opview, ConstantOp)
and _is_integer_like_type(i.type)
)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index fbe4078782997..b3dd79c7dbd79 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -310,8 +310,6 @@ def __init__(
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
- loc=loc,
- ip=ip,
)
self.regions[0].blocks.append(*input_types)
@@ -470,54 +468,6 @@ def apply_registered_pass(
).result
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ForeachOp(ForeachOp):
- def __init__(
- self,
- results: Sequence[Type],
- targets: Sequence[Union[Operation, Value, OpView]],
- *,
- with_zip_shortest: Optional[bool] = False,
- loc=None,
- ip=None,
- ):
- targets = [_get_op_result_or_value(target) for target in targets]
- super().__init__(
- results_=results,
- targets=targets,
- with_zip_shortest=with_zip_shortest,
- loc=loc,
- ip=ip,
- )
- self.regions[0].blocks.append(*[target.type for target in targets])
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTargets(self) -> BlockArgumentList:
- return self.regions[0].blocks[0].arguments
-
-
-def foreach(
- results: Sequence[Type],
- targets: Sequence[Union[Operation, Value, OpView]],
- *,
- with_zip_shortest: Optional[bool] = False,
- loc=None,
- ip=None,
-) -> Union[OpResult, OpResultList, ForeachOp]:
- results = ForeachOp(
- results=results,
- targets=targets,
- with_zip_shortest=with_zip_shortest,
- loc=loc,
- ip=ip,
- ).results
- return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
-
-
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index dfcc890b83ffc..f58442d04fc66 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -401,55 +401,3 @@ def testApplyRegisteredPassOp(module: Module):
options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()
-
-
-# CHECK-LABEL: TEST: testForeachOp
-@run
-def testForeachOp(module: Module):
- # CHECK: transform.sequence
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [transform.AnyOpType.get()],
- transform.AnyOpType.get(),
- )
- with InsertionPoint(sequence.body):
- # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
- foreach1 = transform.ForeachOp(
- (transform.AnyOpType.get(),), (sequence.bodyTarget,)
- )
- with InsertionPoint(foreach1.body):
- # CHECK: transform.yield {{.*}} : !transform.any_op
- transform.yield_(foreach1.bodyTargets)
-
- a_val = transform.get_operand(
- transform.AnyValueType.get(), foreach1.result, [0]
- )
- a_param = transform.param_constant(
- transform.AnyParamType.get(), StringAttr.get("a_param")
- )
-
- # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
- foreach2 = transform.foreach(
- (transform.AnyValueType.get(), transform.AnyParamType.get()),
- (sequence.bodyTarget, a_val, a_param),
- )
- with InsertionPoint(foreach2.owner.body):
- # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
- transform.yield_(foreach2.owner.bodyTargets[1:3])
-
- another_param = transform.param_constant(
- transform.AnyParamType.get(), StringAttr.get("another_param")
- )
- params = transform.merge_handles([a_param, another_param])
-
- # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
- foreach3 = transform.foreach(
- (transform.AnyOpType.get(),),
- (foreach1.result, foreach2[1], params),
- with_zip_shortest=True,
- )
- with InsertionPoint(foreach3.owner.body):
- # CHECK: transform.yield {{.*}} : !transform.any_op
- transform.yield_((foreach3.owner.bodyTargets[0],))
-
- transform.yield_((foreach3,))
More information about the llvm-branch-commits
mailing list