[Mlir-commits] [mlir] [mlir][Interfaces][NFC] Move region loop detection to `RegionBranchOpInterface` (PR #77090)

Matthias Springer llvmlistbot at llvm.org
Fri Jan 5 06:11:37 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/77090

>From 33c35adac5a5ab4a93b47d9d38197ef7c66625c6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 5 Jan 2024 14:10:39 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] Move region loop detection to
 `RegionBranchOpInterface`

`BufferPlacementTransformationBase::isLoop` checks if there is a loop in the region branching graph of an operation. This algorithm is similar to `isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate code, `isRegionReachable` is generalized, so that it can be used to detect region loops. A helper function `RegionBranchOpInterface::hasLoop` is added.

This change also turns a recursive implementation into an iterative one, which is the preferred implementation strategy in LLVM.
---
 .../Bufferization/Transforms/BufferUtils.h    |  6 ---
 .../mlir/Interfaces/ControlFlowInterfaces.td  |  4 ++
 .../Transforms/BufferOptimizations.cpp        | 28 ++++++++---
 .../Bufferization/Transforms/BufferUtils.cpp  | 47 ------------------
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 49 ++++++++++++++++---
 5 files changed, 69 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index 85e9c47ad5302c..8b9c1f0b7282f6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -100,12 +100,6 @@ class BufferPlacementTransformationBase {
     return dom;
   }
 
-  /// Returns true if the given operation represents a loop by testing whether
-  /// it implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`.
-  /// In the case of a `RegionBranchOpInterface`, it checks all region-based
-  /// control-flow edges for cycles.
-  static bool isLoop(Operation *op);
-
   /// Constructs a new operation base using the given root operation.
   BufferPlacementTransformationBase(Operation *op);
 
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 120ddf01ebce5c..95ac5dea243aa4 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     /// eventually branch back to the same region. (Maybe after passing through
     /// other regions.)
     bool isRepetitiveRegion(unsigned index);
+
+    /// Return `true` if there is a loop in the region branching graph. Only
+    /// reachable regions (starting from the entry regions) are considered.
+    bool hasLoop();
   }];
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
index 9f5d6a466780ad..9dc2f262a51161 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -40,6 +40,25 @@ static bool isKnownControlFlowInterface(Operation *op) {
   return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
 }
 
+/// Returns true if the given operation represents a loop by testing whether it
+/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
+/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
+/// flow edges for cycles.
+static bool isLoop(Operation *op) {
+  // If the operation implements the `LoopLikeOpInterface` it can be considered
+  // a loop.
+  if (isa<LoopLikeOpInterface>(op))
+    return true;
+
+  // If the operation does not implement the `RegionBranchOpInterface`, it is
+  // (currently) not possible to detect a loop.
+  auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+  if (!regionInterface)
+    return false;
+
+  return regionInterface.hasLoop();
+}
+
 /// Returns true if the given operation implements the AllocationOpInterface
 /// and it supports the dominate block hoisting.
 static bool allowAllocDominateBlockHoisting(Operation *op) {
@@ -115,8 +134,7 @@ static bool hasAllocationScope(Value alloc,
       // Check if the operation is a known control flow interface and break the
       // loop to avoid transformation in loops. Furthermore skip transformation
       // if the operation does not implement a RegionBeanchOpInterface.
-      if (BufferPlacementTransformationBase::isLoop(parentOp) ||
-          !isKnownControlFlowInterface(parentOp))
+      if (isLoop(parentOp) || !isKnownControlFlowInterface(parentOp))
         break;
     }
   } while ((region = region->getParentRegion()));
@@ -290,9 +308,7 @@ struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
   }
 
   /// Returns true if the given operation does not represent a loop.
-  bool isLegalPlacement(Operation *op) {
-    return !BufferPlacementTransformationBase::isLoop(op);
-  }
+  bool isLegalPlacement(Operation *op) { return !isLoop(op); }
 
   /// Returns true if the given operation should be considered for hoisting.
   static bool shouldHoistOpType(Operation *op) {
@@ -327,7 +343,7 @@ struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
   /// given loop operation. If this is the case, it indicates that the
   /// allocation is passed via a back edge.
   bool isLegalPlacement(Operation *op) {
-    return BufferPlacementTransformationBase::isLoop(op) &&
+    return isLoop(op) &&
            !dominators->dominates(aliasDominatorBlock, op->getBlock());
   }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..8fffdbf664c3f4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -96,53 +96,6 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
     Operation *op)
     : aliases(op), allocs(op), liveness(op) {}
 
-/// Returns true if the given operation represents a loop by testing whether it
-/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
-/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
-/// flow edges for cycles.
-bool BufferPlacementTransformationBase::isLoop(Operation *op) {
-  // If the operation implements the `LoopLikeOpInterface` it can be considered
-  // a loop.
-  if (isa<LoopLikeOpInterface>(op))
-    return true;
-
-  // If the operation does not implement the `RegionBranchOpInterface`, it is
-  // (currently) not possible to detect a loop.
-  RegionBranchOpInterface regionInterface;
-  if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
-    return false;
-
-  // Recurses into a region using the current region interface to find potential
-  // cycles.
-  SmallPtrSet<Region *, 4> visitedRegions;
-  std::function<bool(Region *)> recurse = [&](Region *current) {
-    if (!current)
-      return false;
-    // If we have found a back edge, the parent operation induces a loop.
-    if (!visitedRegions.insert(current).second)
-      return true;
-    // Recurses into all region successors.
-    SmallVector<RegionSuccessor, 2> successors;
-    regionInterface.getSuccessorRegions(current, successors);
-    for (RegionSuccessor &regionEntry : successors)
-      if (recurse(regionEntry.getSuccessor()))
-        return true;
-    return false;
-  };
-
-  // Start with all entry regions and test whether they induce a loop.
-  SmallVector<RegionSuccessor, 2> successorRegions;
-  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
-                                      successorRegions);
-  for (RegionSuccessor &regionEntry : successorRegions) {
-    if (recurse(regionEntry.getSuccessor()))
-      return true;
-    visitedRegions.clear();
-  }
-
-  return false;
-}
-
 //===----------------------------------------------------------------------===//
 // BufferPlacementTransformationBase
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a563ec5cb8db58..a1ea22dbfc6937 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,11 +219,21 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   return success();
 }
 
-/// Return `true` if region `r` is reachable from region `begin` according to
-/// the RegionBranchOpInterface (by taking a branch).
-static bool isRegionReachable(Region *begin, Region *r) {
-  assert(begin->getParentOp() == r->getParentOp() &&
-         "expected that both regions belong to the same op");
+namespace {
+/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
+/// this function returns "true" for a successor region. The first parameter is
+/// the successor region. The second parameter indicates all already visited
+/// regions.
+using StopConditionFn =
+    std::function<bool(Region *, const SmallVector<bool> &visited)>;
+} // namespace
+
+/// Traverse the region graph starting at `begin`. The traversal is interrupted
+/// if `stopCondition` evaluates to "true" for a successor region. In that case,
+/// this function returns "true". Otherwise, if the traversal was not
+/// interrupted, this function returns "false".
+static bool traverseRegionGraph(Region *begin,
+                                StopConditionFn stopConditionFn) {
   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
   SmallVector<bool> visited(op->getNumRegions(), false);
   visited[begin->getRegionNumber()] = true;
@@ -242,7 +252,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
   // Process all regions in the worklist via DFS.
   while (!worklist.empty()) {
     Region *nextRegion = worklist.pop_back_val();
-    if (nextRegion == r)
+    if (stopConditionFn(nextRegion, visited))
       return true;
     if (visited[nextRegion->getRegionNumber()])
       continue;
@@ -253,6 +263,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
   return false;
 }
 
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch).
+static bool isRegionReachable(Region *begin, Region *r) {
+  assert(begin->getParentOp() == r->getParentOp() &&
+         "expected that both regions belong to the same op");
+  return traverseRegionGraph(
+      begin, [&](Region *nextRegion, const SmallVector<bool> &visited) {
+        // Interrupt traversal if `r` was reached.
+        return nextRegion == r;
+      });
+}
+
 /// Return `true` if `a` and `b` are in mutually exclusive regions.
 ///
 /// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +328,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
   return isRegionReachable(region, region);
 }
 
+bool RegionBranchOpInterface::hasLoop() {
+  SmallVector<RegionSuccessor> entryRegions;
+  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
+  for (RegionSuccessor successor : entryRegions)
+    if (!successor.isParent() &&
+        traverseRegionGraph(
+            successor.getSuccessor(),
+            [](Region *nextRegion, const SmallVector<bool> &visited) {
+              // Interrupt traversal if the region was already visited.
+              return visited[nextRegion->getRegionNumber()];
+            }))
+      return true;
+  return false;
+}
+
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
   while (Region *region = op->getParentRegion()) {
     op = region->getParentOp();



More information about the Mlir-commits mailing list