[Mlir-commits] [mlir] [mlir][transform] Make `yield` a `ReturnLike` op. (PR #111408)

Ingo Müller llvmlistbot at llvm.org
Thu Oct 10 01:52:26 PDT 2024


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/111408

>From e7be852053771fb1c7fa84f4105b675e42201d7a Mon Sep 17 00:00:00 2001
From: Ingo Müller <ingomueller at google.com>
Date: Mon, 7 Oct 2024 17:13:20 +0000
Subject: [PATCH 1/2] [mlir][transform] Make `yield` a `ReturnLike` op.

This PR adds the `ReturnLike` trait to `transform.yield`. This is
required by the one-shot bufferization pass, which, since the merging of
#110332, analyzes every op implementing `FunctionOpInterface` and
expects its terminator to be `ReturnLike`.
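A minimal, MLIR-free C++ sketch of the kind of precondition described above, assuming the analysis only needs to know whether the terminator of a function-like op's body carries the `ReturnLike` trait. `Trait`, `ToyOp`, `ToyFunctionOp`, and `hasReturnLikeTerminator` are hypothetical stand-ins, not MLIR's API; a real implementation might instead query something like `op->hasTrait<OpTrait::ReturnLike>()`.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-ins for op traits; this is not MLIR's API.
enum class Trait { Terminator, ReturnLike };

// A toy operation: a name plus the set of traits it carries.
struct ToyOp {
  std::string name;
  std::set<Trait> traits;
  bool hasTrait(Trait trait) const { return traits.count(trait) != 0; }
};

// A toy function-like op whose (single-block) body ends in `terminator`.
struct ToyFunctionOp {
  ToyOp terminator;
};

// Sketch of the precondition: in this model, a function-like op is only
// accepted if the terminator of its body is ReturnLike.
bool hasReturnLikeTerminator(const ToyFunctionOp &func) {
  return func.terminator.hasTrait(Trait::ReturnLike);
}
```

For example, a toy `transform.yield` carrying only `Trait::Terminator` fails this check, while one that also carries `Trait::ReturnLike` passes it.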

Signed-off-by: Ingo Müller <ingomueller at google.com>
---
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b946fc8875860b..d3933cad920a3f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1358,7 +1358,8 @@ def VerifyOp : TransformDialectOp<"verify",
 }
 
 def YieldOp : TransformDialectOp<"yield",
-    [Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+    [Terminator, ReturnLike,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Yields operation handles from a transform IR region";
   let description = [{
     This terminator operation yields operation handles from regions of the

>From d626f95564da4f82313001405437f8cbff30511e Mon Sep 17 00:00:00 2001
From: Ingo Müller <ingomueller at google.com>
Date: Thu, 10 Oct 2024 08:34:04 +0000
Subject: [PATCH 2/2] Overhaul `RegionBranchOpInterface` impls of ops using `yield`.

If we give `transform.yield` the `ReturnLike` trait, then the checks
made by `RegionBranchOpInterface` behave differently. More concretely,
this affects the check that the operands and results passed across the
control-flow edges from the parent op to its regions and back have equal
(or compatible) types. Which values are passed back from a region to the
parent op depends on the terminator and, in particular, on whether it is
`ReturnLike`.
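The check described above can be modeled, very roughly, in plain C++. This is a hypothetical sketch, not MLIR code: `Type`, `Source`, `Edge`, and `verifyEdges` are made-up names, types are plain strings, and exact equality stands in for the interface's compatibility check. The key assumption it encodes is the one stated in this commit message: edges that leave a region through a terminator that is not `ReturnLike` are not type-checked.

```cpp
#include <cassert>
#include <string>
#include <vector>

using Type = std::string;  // Hypothetical: types modeled as strings.

// Where a control-flow edge starts.
enum class Source { Parent, ReturnLikeTerminator, OtherTerminator };

// One control-flow edge between the parent op and a region (either way).
struct Edge {
  Source source;
  std::vector<Type> forwarded;  // Types of the values passed along the edge.
  std::vector<Type> expected;   // Types of the successor's inputs.
};

// Returns true if every checked edge forwards exactly the expected types.
// Edges leaving a region through a non-ReturnLike terminator are skipped,
// so their mismatches go unnoticed. (The real check also accepts
// compatible types; this sketch requires exact equality.)
bool verifyEdges(const std::vector<Edge> &edges) {
  for (const Edge &edge : edges) {
    if (edge.source == Source::OtherTerminator)
      continue;
    if (edge.forwarded != edge.expected)
      return false;
  }
  return true;
}
```

With an edge that feeds `!transform.any_value` yields into `!transform.any_op` block arguments, `verifyEdges` returns true while the source is `Source::OtherTerminator` and false once it becomes `Source::ReturnLikeTerminator`.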

This commit radically changes the `getSuccessorRegions` implementations
of the `alternatives`, `foreach`, and `sequence` ops, which are the ops
that use `yield` as their terminator. In fact, all of these ops only
ever pass control from the parent op to a region and from there back to
the parent op. In particular, and unlike `scf.for`, `transform.foreach`
does *not* pass control directly from one iteration of its body to the
next; rather, it passes control back to the parent op, which then passes
it to the body again for the next iteration. This can be seen from the
fact that the body always receives arguments of the same types as the
operands of the parent op (and never of the yielded types), and that the
yielded types correspond exactly to the result types of the parent op.
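The difference between the two control-flow shapes can be sketched in plain C++. The helpers `scfForLike` and `transformForeachLike` are hypothetical, with `int` standing in for values and handles; this is not how either op is implemented. In the `scf.for`-like loop, each iteration's yielded value feeds the next iteration's block argument; in the `transform.foreach`-like loop, the body only ever receives elements of the parent's operand, and the parent collects the yielded values into its results.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// scf.for-like: the value yielded by one iteration becomes the block
// argument of the next one (region -> region edge).
int scfForLike(int numIterations, int init,
               const std::function<int(int)> &body) {
  int carried = init;
  for (int i = 0; i < numIterations; ++i)
    carried = body(carried);  // The yield feeds the next iteration directly.
  return carried;
}

// transform.foreach-like: control returns to the parent after every
// iteration; the body always receives an element of the parent's operand
// (parent -> region edge), and its yield is appended to the parent's
// results (region -> parent edge).
std::vector<int> transformForeachLike(const std::vector<int> &operand,
                                      const std::function<int(int)> &body) {
  std::vector<int> results;
  for (int element : operand)
    results.push_back(body(element));  // The yield goes to the parent only.
  return results;
}
```

With a body that doubles its argument, `scfForLike(3, 1, body)` threads 1, 2, 4, 8 and returns 8 because each yield is fed back in, while `transformForeachLike({1, 2, 3}, body)` returns `{2, 4, 6}` because each iteration starts from an operand element.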

It is unclear why the previous implementation worked. Clearly, the type
checks were not executed because the terminator was not `ReturnLike`,
so the mismatching types went unnoticed. I suppose that there simply
were no other users of the `RegionBranchOpInterface` of the affected
ops yet.
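To make the mismatch concrete for `foreach`, here is a hypothetical, MLIR-free model of the successors that the old and new `ForeachOp::getSuccessorRegions` report when control leaves the body. `Successor`, `oldForeachEdgesFromBody`, `newForeachEdgesFromBody`, and `typesMatch` are made-up names; `!transform.any_op` as the body argument type and `!transform.any_value` as the yielded/result type are arbitrary example choices; and exact equality stands in for the interface's compatibility check.

```cpp
#include <cassert>
#include <string>
#include <vector>

using Type = std::string;  // Hypothetical: types modeled as strings.

// Hypothetical model of a successor of the body region: where control
// goes and which input types that successor expects.
struct Successor {
  std::string target;        // "body" or "parent" (for illustration only).
  std::vector<Type> inputs;  // Types of the successor's inputs.
};

const Type kOperandType = "!transform.any_op";    // Body argument type.
const Type kResultType = "!transform.any_value";  // Yielded/result type.

// Old behavior from the body: branch back to the body (inputs = body
// arguments) and to the parent with no inputs (`regions.emplace_back()`).
std::vector<Successor> oldForeachEdgesFromBody() {
  return {{"body", {kOperandType}}, {"parent", {}}};
}

// New behavior from the body: only branch to the parent, whose inputs are
// the op results.
std::vector<Successor> newForeachEdgesFromBody() {
  return {{"parent", {kResultType}}};
}

// With a ReturnLike `yield`, the yielded types must match the inputs of
// every successor of the body.
bool typesMatch(const std::vector<Type> &yielded,
                const std::vector<Successor> &successors) {
  for (const Successor &successor : successors)
    if (successor.inputs != yielded)
      return false;
  return true;
}
```

Here `typesMatch({kResultType}, oldForeachEdgesFromBody())` is false, since neither old successor expects a `!transform.any_value`, whereas the single new successor accepts exactly the yielded type.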

Signed-off-by: Ingo Müller <ingomueller at google.com>
---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 34 +++----------------
 1 file changed, 5 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 590cae9aa0d667..48b25d19d7dc31 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -104,16 +104,8 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 
 void transform::AlternativesOp::getSuccessorRegions(
     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  for (Region &alternative : llvm::drop_begin(
-           getAlternatives(),
-           point.isParent() ? 0
-                            : point.getRegionOrNull()->getRegionNumber() + 1)) {
-    regions.emplace_back(&alternative, !getOperands().empty()
-                                           ? alternative.getArguments()
-                                           : Block::BlockArgListType());
-  }
   if (!point.isParent())
-    regions.emplace_back(getOperation()->getResults());
+    regions.emplace_back(getResults());
 }
 
 void transform::AlternativesOp::getRegionInvocationBounds(
@@ -1502,16 +1494,8 @@ void transform::ForeachOp::getEffects(
 
 void transform::ForeachOp::getSuccessorRegions(
     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  Region *bodyRegion = &getBody();
-  if (point.isParent()) {
-    regions.emplace_back(bodyRegion, bodyRegion->getArguments());
-    return;
-  }
-
-  // Branch back to the region or the parent.
-  assert(point == getBody() && "unexpected region index");
-  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
-  regions.emplace_back();
+  if (point.getRegionOrNull() == &getBody())
+    regions.emplace_back(getResults());
 }
 
 OperandRange
@@ -2702,16 +2686,8 @@ transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 
 void transform::SequenceOp::getSuccessorRegions(
     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  if (point.isParent()) {
-    Region *bodyRegion = &getBody();
-    regions.emplace_back(bodyRegion, getNumOperands() != 0
-                                         ? bodyRegion->getArguments()
-                                         : Block::BlockArgListType());
-    return;
-  }
-
-  assert(point == getBody() && "unexpected region index");
-  regions.emplace_back(getOperation()->getResults());
+  if (point.getRegionOrNull() == &getBody())
+    regions.emplace_back(getResults());
 }
 
 void transform::SequenceOp::getRegionInvocationBounds(


