[Mlir-commits] [mlir] [mlir][Interfaces][NFC] Move region loop detection to `RegionBranchOpInterface` (PR #77090)
Matthias Springer
llvmlistbot at llvm.org
Sun Jan 7 04:41:20 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/77090
>From 30f4424753618b33d69e9d4caef8c319984699b7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 5 Jan 2024 14:10:39 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] Move region loop detection to
`RegionBranchOpInterface`
`BufferPlacementTransformationBase::isLoop` checks if there is a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added.
This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
---
.../Bufferization/Transforms/BufferUtils.h | 6 ---
.../mlir/Interfaces/ControlFlowInterfaces.td | 4 ++
.../Transforms/BufferOptimizations.cpp | 28 ++++++++---
.../Bufferization/Transforms/BufferUtils.cpp | 47 -------------------
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 46 +++++++++++++++---
5 files changed, 66 insertions(+), 65 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index 85e9c47ad5302c..8b9c1f0b7282f6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -100,12 +100,6 @@ class BufferPlacementTransformationBase {
return dom;
}
- /// Returns true if the given operation represents a loop by testing whether
- /// it implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`.
- /// In the case of a `RegionBranchOpInterface`, it checks all region-based
- /// control-flow edges for cycles.
- static bool isLoop(Operation *op);
-
/// Constructs a new operation base using the given root operation.
BufferPlacementTransformationBase(Operation *op);
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 120ddf01ebce5c..95ac5dea243aa4 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
/// eventually branch back to the same region. (Maybe after passing through
/// other regions.)
bool isRepetitiveRegion(unsigned index);
+
+ /// Return `true` if there is a loop in the region branching graph. Only
+ /// reachable regions (starting from the entry regions) are considered.
+ bool hasLoop();
}];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
index 9f5d6a466780ad..9dc2f262a51161 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -40,6 +40,25 @@ static bool isKnownControlFlowInterface(Operation *op) {
return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
}
+/// Returns true if the given operation represents a loop by testing whether it
+/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
+/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
+/// flow edges for cycles.
+static bool isLoop(Operation *op) {
+ // If the operation implements the `LoopLikeOpInterface` it can be considered
+ // a loop.
+ if (isa<LoopLikeOpInterface>(op))
+ return true;
+
+ // If the operation does not implement the `RegionBranchOpInterface`, it is
+ // (currently) not possible to detect a loop.
+ auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+ if (!regionInterface)
+ return false;
+
+ return regionInterface.hasLoop();
+}
+
/// Returns true if the given operation implements the AllocationOpInterface
/// and it supports the dominate block hoisting.
static bool allowAllocDominateBlockHoisting(Operation *op) {
@@ -115,8 +134,7 @@ static bool hasAllocationScope(Value alloc,
// Check if the operation is a known control flow interface and break the
// loop to avoid transformation in loops. Furthermore skip transformation
// if the operation does not implement a RegionBeanchOpInterface.
- if (BufferPlacementTransformationBase::isLoop(parentOp) ||
- !isKnownControlFlowInterface(parentOp))
+ if (isLoop(parentOp) || !isKnownControlFlowInterface(parentOp))
break;
}
} while ((region = region->getParentRegion()));
@@ -290,9 +308,7 @@ struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
}
/// Returns true if the given operation does not represent a loop.
- bool isLegalPlacement(Operation *op) {
- return !BufferPlacementTransformationBase::isLoop(op);
- }
+ bool isLegalPlacement(Operation *op) { return !isLoop(op); }
/// Returns true if the given operation should be considered for hoisting.
static bool shouldHoistOpType(Operation *op) {
@@ -327,7 +343,7 @@ struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
/// given loop operation. If this is the case, it indicates that the
/// allocation is passed via a back edge.
bool isLegalPlacement(Operation *op) {
- return BufferPlacementTransformationBase::isLoop(op) &&
+ return isLoop(op) &&
!dominators->dominates(aliasDominatorBlock, op->getBlock());
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..8fffdbf664c3f4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -96,53 +96,6 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
Operation *op)
: aliases(op), allocs(op), liveness(op) {}
-/// Returns true if the given operation represents a loop by testing whether it
-/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
-/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
-/// flow edges for cycles.
-bool BufferPlacementTransformationBase::isLoop(Operation *op) {
- // If the operation implements the `LoopLikeOpInterface` it can be considered
- // a loop.
- if (isa<LoopLikeOpInterface>(op))
- return true;
-
- // If the operation does not implement the `RegionBranchOpInterface`, it is
- // (currently) not possible to detect a loop.
- RegionBranchOpInterface regionInterface;
- if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
- return false;
-
- // Recurses into a region using the current region interface to find potential
- // cycles.
- SmallPtrSet<Region *, 4> visitedRegions;
- std::function<bool(Region *)> recurse = [&](Region *current) {
- if (!current)
- return false;
- // If we have found a back edge, the parent operation induces a loop.
- if (!visitedRegions.insert(current).second)
- return true;
- // Recurses into all region successors.
- SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(current, successors);
- for (RegionSuccessor ®ionEntry : successors)
- if (recurse(regionEntry.getSuccessor()))
- return true;
- return false;
- };
-
- // Start with all entry regions and test whether they induce a loop.
- SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
- successorRegions);
- for (RegionSuccessor ®ionEntry : successorRegions) {
- if (recurse(regionEntry.getSuccessor()))
- return true;
- visitedRegions.clear();
- }
-
- return false;
-}
-
//===----------------------------------------------------------------------===//
// BufferPlacementTransformationBase
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a563ec5cb8db58..6d530ca38e24be 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,11 +219,18 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return success();
}
-/// Return `true` if region `r` is reachable from region `begin` according to
-/// the RegionBranchOpInterface (by taking a branch).
-static bool isRegionReachable(Region *begin, Region *r) {
- assert(begin->getParentOp() == r->getParentOp() &&
- "expected that both regions belong to the same op");
+/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
+/// this function returns "true" for a successor region. The first parameter is
+/// the successor region. The second parameter indicates all already visited
+/// regions.
+using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
+
+/// Traverse the region graph starting at `begin`. The traversal is interrupted
+/// if `stopCondition` evaluates to "true" for a successor region. In that case,
+/// this function returns "true". Otherwise, if the traversal was not
+/// interrupted, this function returns "false".
+static bool traverseRegionGraph(Region *begin,
+ StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
@@ -242,7 +249,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
- if (nextRegion == r)
+ if (stopConditionFn(nextRegion, visited))
return true;
if (visited[nextRegion->getRegionNumber()])
continue;
@@ -253,6 +260,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
return false;
}
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch).
+static bool isRegionReachable(Region *begin, Region *r) {
+ assert(begin->getParentOp() == r->getParentOp() &&
+ "expected that both regions belong to the same op");
+ return traverseRegionGraph(
+ begin, [&](Region *nextRegion, const SmallVector<bool> &visited) {
+ // Interrupt traversal if `r` was reached.
+ return nextRegion == r;
+ });
+}
+
/// Return `true` if `a` and `b` are in mutually exclusive regions.
///
/// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +325,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
return isRegionReachable(region, region);
}
+bool RegionBranchOpInterface::hasLoop() {
+ SmallVector<RegionSuccessor> entryRegions;
+ getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
+ for (RegionSuccessor successor : entryRegions)
+ if (!successor.isParent() &&
+ traverseRegionGraph(
+ successor.getSuccessor(),
+ [](Region *nextRegion, const SmallVector<bool> &visited) {
+ // Interrupt traversal if the region was already visited.
+ return visited[nextRegion->getRegionNumber()];
+ }))
+ return true;
+ return false;
+}
+
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
while (Region *region = op->getParentRegion()) {
op = region->getParentOp();
More information about the Mlir-commits
mailing list