[Mlir-commits] [mlir] [mlir][Interfaces][NFC] Move region loop detection to `RegionBranchOpInterface` (PR #77090)
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 5 04:59:10 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/77090
`BufferPlacementTransformationBase::isLoop` checks if there is a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added.
This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
>From 805267fc0799416a0144b1e33e25f801227b8f52 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 5 Jan 2024 12:52:11 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] Move region loop detection to
`RegionBranchOpInterface`
`BufferPlacementTransformationBase::isLoop` checks if there is a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added.
This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
---
.../mlir/Interfaces/ControlFlowInterfaces.td | 4 ++
.../Bufferization/Transforms/BufferUtils.cpp | 34 ++-----------
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 49 ++++++++++++++++---
3 files changed, 50 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 120ddf01ebce5c..95ac5dea243aa4 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
/// eventually branch back to the same region. (Maybe after passing through
/// other regions.)
bool isRepetitiveRegion(unsigned index);
+
+ /// Return `true` if there is a loop in the region branching graph. Only
+ /// reachable regions (starting from the entry regions) are considered.
+ bool hasLoop();
}];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..227a3df8fb9974 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -108,39 +108,11 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
// If the operation does not implement the `RegionBranchOpInterface`, it is
// (currently) not possible to detect a loop.
- RegionBranchOpInterface regionInterface;
- if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
+ auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+ if (!regionInterface)
return false;
- // Recurses into a region using the current region interface to find potential
- // cycles.
- SmallPtrSet<Region *, 4> visitedRegions;
- std::function<bool(Region *)> recurse = [&](Region *current) {
- if (!current)
- return false;
- // If we have found a back edge, the parent operation induces a loop.
- if (!visitedRegions.insert(current).second)
- return true;
- // Recurses into all region successors.
- SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(current, successors);
- for (RegionSuccessor &regionEntry : successors)
- if (recurse(regionEntry.getSuccessor()))
- return true;
- return false;
- };
-
- // Start with all entry regions and test whether they induce a loop.
- SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
- successorRegions);
- for (RegionSuccessor &regionEntry : successorRegions) {
- if (recurse(regionEntry.getSuccessor()))
- return true;
- visitedRegions.clear();
- }
-
- return false;
+ return regionInterface.hasLoop();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a563ec5cb8db58..a1ea22dbfc6937 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,11 +219,21 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return success();
}
-/// Return `true` if region `r` is reachable from region `begin` according to
-/// the RegionBranchOpInterface (by taking a branch).
-static bool isRegionReachable(Region *begin, Region *r) {
- assert(begin->getParentOp() == r->getParentOp() &&
- "expected that both regions belong to the same op");
+namespace {
+/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
+/// this function returns "true" for a successor region. The first parameter is
+/// the successor region. The second parameter indicates all already visited
+/// regions.
+using StopConditionFn =
+ std::function<bool(Region *, const SmallVector<bool> &visited)>;
+} // namespace
+
+/// Traverse the region graph starting at `begin`. The traversal is interrupted
+/// if `stopCondition` evaluates to "true" for a successor region. In that case,
+/// this function returns "true". Otherwise, if the traversal was not
+/// interrupted, this function returns "false".
+static bool traverseRegionGraph(Region *begin,
+ StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
@@ -242,7 +252,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
- if (nextRegion == r)
+ if (stopConditionFn(nextRegion, visited))
return true;
if (visited[nextRegion->getRegionNumber()])
continue;
@@ -253,6 +263,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
return false;
}
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch).
+static bool isRegionReachable(Region *begin, Region *r) {
+ assert(begin->getParentOp() == r->getParentOp() &&
+ "expected that both regions belong to the same op");
+ return traverseRegionGraph(
+ begin, [&](Region *nextRegion, const SmallVector<bool> &visited) {
+ // Interrupt traversal if `r` was reached.
+ return nextRegion == r;
+ });
+}
+
/// Return `true` if `a` and `b` are in mutually exclusive regions.
///
/// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +328,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
return isRegionReachable(region, region);
}
+bool RegionBranchOpInterface::hasLoop() {
+ SmallVector<RegionSuccessor> entryRegions;
+ getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
+ for (RegionSuccessor successor : entryRegions)
+ if (!successor.isParent() &&
+ traverseRegionGraph(
+ successor.getSuccessor(),
+ [](Region *nextRegion, const SmallVector<bool> &visited) {
+ // Interrupt traversal if the region was already visited.
+ return visited[nextRegion->getRegionNumber()];
+ }))
+ return true;
+ return false;
+}
+
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
while (Region *region = op->getParentRegion()) {
op = region->getParentOp();
More information about the Mlir-commits
mailing list