[Mlir-commits] [mlir] [mlir][Interfaces][NFC] Move region loop detection to `RegionBranchOpInterface` (PR #77090)

Matthias Springer llvmlistbot at llvm.org
Fri Jan 5 04:59:10 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/77090

`BufferPlacementTransformationBase::isLoop` checks whether there is a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added.

This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.

>From 805267fc0799416a0144b1e33e25f801227b8f52 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 5 Jan 2024 12:52:11 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] Move region loop detection to
 `RegionBranchOpInterface`

`BufferPlacementTransformationBase::isLoop` checks whether there is a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added.

This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
---
 .../mlir/Interfaces/ControlFlowInterfaces.td  |  4 ++
 .../Bufferization/Transforms/BufferUtils.cpp  | 34 ++-----------
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 49 ++++++++++++++++---
 3 files changed, 50 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 120ddf01ebce5c..95ac5dea243aa4 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     /// eventually branch back to the same region. (Maybe after passing through
     /// other regions.)
     bool isRepetitiveRegion(unsigned index);
+
+    /// Return `true` if there is a loop in the region branching graph. Only
+    /// reachable regions (starting from the entry regions) are considered.
+    bool hasLoop();
   }];
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..227a3df8fb9974 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -108,39 +108,11 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
 
   // If the operation does not implement the `RegionBranchOpInterface`, it is
   // (currently) not possible to detect a loop.
-  RegionBranchOpInterface regionInterface;
-  if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
+  auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+  if (!regionInterface)
     return false;
 
-  // Recurses into a region using the current region interface to find potential
-  // cycles.
-  SmallPtrSet<Region *, 4> visitedRegions;
-  std::function<bool(Region *)> recurse = [&](Region *current) {
-    if (!current)
-      return false;
-    // If we have found a back edge, the parent operation induces a loop.
-    if (!visitedRegions.insert(current).second)
-      return true;
-    // Recurses into all region successors.
-    SmallVector<RegionSuccessor, 2> successors;
-    regionInterface.getSuccessorRegions(current, successors);
-    for (RegionSuccessor &regionEntry : successors)
-      if (recurse(regionEntry.getSuccessor()))
-        return true;
-    return false;
-  };
-
-  // Start with all entry regions and test whether they induce a loop.
-  SmallVector<RegionSuccessor, 2> successorRegions;
-  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
-                                      successorRegions);
-  for (RegionSuccessor &regionEntry : successorRegions) {
-    if (recurse(regionEntry.getSuccessor()))
-      return true;
-    visitedRegions.clear();
-  }
-
-  return false;
+  return regionInterface.hasLoop();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a563ec5cb8db58..a1ea22dbfc6937 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,11 +219,21 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   return success();
 }
 
-/// Return `true` if region `r` is reachable from region `begin` according to
-/// the RegionBranchOpInterface (by taking a branch).
-static bool isRegionReachable(Region *begin, Region *r) {
-  assert(begin->getParentOp() == r->getParentOp() &&
-         "expected that both regions belong to the same op");
+namespace {
+/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
+/// this function returns "true" for a successor region. The first parameter is
+/// the successor region. The second parameter marks the regions visited so far
+/// (one flag per region of the parent op, indexed by region number).
+using StopConditionFn =
+    std::function<bool(Region *, const SmallVector<bool> &visited)>;
+} // namespace
+
+/// Traverse the region graph starting at `begin`. The traversal is interrupted
+/// if `stopCondition` evaluates to "true" for a successor region. In that case,
+/// this function returns "true". Otherwise, if the traversal was not
+/// interrupted, this function returns "false".
+static bool traverseRegionGraph(Region *begin,
+                                StopConditionFn stopConditionFn) {
   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
   SmallVector<bool> visited(op->getNumRegions(), false);
   visited[begin->getRegionNumber()] = true;
@@ -242,7 +252,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
   // Process all regions in the worklist via DFS.
   while (!worklist.empty()) {
     Region *nextRegion = worklist.pop_back_val();
-    if (nextRegion == r)
+    if (stopConditionFn(nextRegion, visited))
       return true;
     if (visited[nextRegion->getRegionNumber()])
       continue;
@@ -253,6 +263,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
   return false;
 }
 
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch).
+static bool isRegionReachable(Region *begin, Region *r) {
+  assert(begin->getParentOp() == r->getParentOp() &&
+         "expected that both regions belong to the same op");
+  return traverseRegionGraph(
+      begin, [&](Region *nextRegion, const SmallVector<bool> &visited) {
+        // Interrupt traversal if `r` was reached.
+        return nextRegion == r;
+      });
+}
+
 /// Return `true` if `a` and `b` are in mutually exclusive regions.
 ///
 /// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +328,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
   return isRegionReachable(region, region);
 }
 
+bool RegionBranchOpInterface::hasLoop() {
+  SmallVector<RegionSuccessor> entryRegions;
+  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
+  for (RegionSuccessor successor : entryRegions)
+    if (!successor.isParent() &&
+        traverseRegionGraph(
+            successor.getSuccessor(),
+            [](Region *nextRegion, const SmallVector<bool> &visited) {
+              // Interrupt traversal if the region was already visited.
+              return visited[nextRegion->getRegionNumber()];
+            }))
+      return true;
+  return false;
+}
+
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
   while (Region *region = op->getParentRegion()) {
     op = region->getParentOp();



More information about the Mlir-commits mailing list