[Mlir-commits] [mlir] dd450f0 - [mlir][Interfaces][NFC] Move region loop detection to `RegionBranchOpInterface` (#77090)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 7 04:49:33 PST 2024


Author: Matthias Springer
Date: 2024-01-07T13:49:29+01:00
New Revision: dd450f08cfeb9da372cbe459058bc9ae9425f862

URL: https://github.com/llvm/llvm-project/commit/dd450f08cfeb9da372cbe459058bc9ae9425f862
DIFF: https://github.com/llvm/llvm-project/commit/dd450f08cfeb9da372cbe459058bc9ae9425f862.diff

LOG: [mlir][Interfaces][NFC] Move region loop detection to `RegionBranchOpInterface` (#77090)

`BufferPlacementTransformationBase::isLoop` checks if there is a loop in
the region branching graph of an operation. This algorithm is similar to
`isRegionReachable` in the `RegionBranchOpInterface`. To avoid duplicate
code, `isRegionReachable` is generalized, so that it can be used to
detect region loops. A helper function
`RegionBranchOpInterface::hasLoop` is added.

This change also turns a recursive implementation into an iterative one,
which is the preferred implementation strategy in LLVM.

Also move `isLoop` to `BufferOptimizations.cpp`, so that we can
gradually retire `BufferPlacementTransformationBase`. (This is so that
proper error handling can be added to `BufferViewFlowAnalysis`.)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index 85e9c47ad5302c..8b9c1f0b7282f6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -100,12 +100,6 @@ class BufferPlacementTransformationBase {
     return dom;
   }
 
-  /// Returns true if the given operation represents a loop by testing whether
-  /// it implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`.
-  /// In the case of a `RegionBranchOpInterface`, it checks all region-based
-  /// control-flow edges for cycles.
-  static bool isLoop(Operation *op);
-
   /// Constructs a new operation base using the given root operation.
   BufferPlacementTransformationBase(Operation *op);
 

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 120ddf01ebce5c..95ac5dea243aa4 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -272,6 +272,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     /// eventually branch back to the same region. (Maybe after passing through
     /// other regions.)
     bool isRepetitiveRegion(unsigned index);
+
+    /// Return `true` if there is a loop in the region branching graph. Only
+    /// reachable regions (starting from the entry regions) are considered.
+    bool hasLoop();
   }];
 }
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
index 9f5d6a466780ad..9dc2f262a51161 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -40,6 +40,25 @@ static bool isKnownControlFlowInterface(Operation *op) {
   return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
 }
 
+/// Returns true if the given operation represents a loop by testing whether it
+/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
+/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
+/// flow edges for cycles.
+static bool isLoop(Operation *op) {
+  // If the operation implements the `LoopLikeOpInterface` it can be considered
+  // a loop.
+  if (isa<LoopLikeOpInterface>(op))
+    return true;
+
+  // If the operation does not implement the `RegionBranchOpInterface`, it is
+  // (currently) not possible to detect a loop.
+  auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+  if (!regionInterface)
+    return false;
+
+  return regionInterface.hasLoop();
+}
+
 /// Returns true if the given operation implements the AllocationOpInterface
 /// and it supports the dominate block hoisting.
 static bool allowAllocDominateBlockHoisting(Operation *op) {
@@ -115,8 +134,7 @@ static bool hasAllocationScope(Value alloc,
       // Check if the operation is a known control flow interface and break the
       // loop to avoid transformation in loops. Furthermore skip transformation
       // if the operation does not implement a RegionBeanchOpInterface.
-      if (BufferPlacementTransformationBase::isLoop(parentOp) ||
-          !isKnownControlFlowInterface(parentOp))
+      if (isLoop(parentOp) || !isKnownControlFlowInterface(parentOp))
         break;
     }
   } while ((region = region->getParentRegion()));
@@ -290,9 +308,7 @@ struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
   }
 
   /// Returns true if the given operation does not represent a loop.
-  bool isLegalPlacement(Operation *op) {
-    return !BufferPlacementTransformationBase::isLoop(op);
-  }
+  bool isLegalPlacement(Operation *op) { return !isLoop(op); }
 
   /// Returns true if the given operation should be considered for hoisting.
   static bool shouldHoistOpType(Operation *op) {
@@ -327,7 +343,7 @@ struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
   /// given loop operation. If this is the case, it indicates that the
   /// allocation is passed via a back edge.
   bool isLegalPlacement(Operation *op) {
-    return BufferPlacementTransformationBase::isLoop(op) &&
+    return isLoop(op) &&
            !dominators->dominates(aliasDominatorBlock, op->getBlock());
   }
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f3..8fffdbf664c3f4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -96,53 +96,6 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
     Operation *op)
     : aliases(op), allocs(op), liveness(op) {}
 
-/// Returns true if the given operation represents a loop by testing whether it
-/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
-/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
-/// flow edges for cycles.
-bool BufferPlacementTransformationBase::isLoop(Operation *op) {
-  // If the operation implements the `LoopLikeOpInterface` it can be considered
-  // a loop.
-  if (isa<LoopLikeOpInterface>(op))
-    return true;
-
-  // If the operation does not implement the `RegionBranchOpInterface`, it is
-  // (currently) not possible to detect a loop.
-  RegionBranchOpInterface regionInterface;
-  if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
-    return false;
-
-  // Recurses into a region using the current region interface to find potential
-  // cycles.
-  SmallPtrSet<Region *, 4> visitedRegions;
-  std::function<bool(Region *)> recurse = [&](Region *current) {
-    if (!current)
-      return false;
-    // If we have found a back edge, the parent operation induces a loop.
-    if (!visitedRegions.insert(current).second)
-      return true;
-    // Recurses into all region successors.
-    SmallVector<RegionSuccessor, 2> successors;
-    regionInterface.getSuccessorRegions(current, successors);
-    for (RegionSuccessor &regionEntry : successors)
-      if (recurse(regionEntry.getSuccessor()))
-        return true;
-    return false;
-  };
-
-  // Start with all entry regions and test whether they induce a loop.
-  SmallVector<RegionSuccessor, 2> successorRegions;
-  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
-                                      successorRegions);
-  for (RegionSuccessor &regionEntry : successorRegions) {
-    if (recurse(regionEntry.getSuccessor()))
-      return true;
-    visitedRegions.clear();
-  }
-
-  return false;
-}
-
 //===----------------------------------------------------------------------===//
 // BufferPlacementTransformationBase
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a563ec5cb8db58..39e5e9997f7dd1 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,11 +219,18 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   return success();
 }
 
-/// Return `true` if region `r` is reachable from region `begin` according to
-/// the RegionBranchOpInterface (by taking a branch).
-static bool isRegionReachable(Region *begin, Region *r) {
-  assert(begin->getParentOp() == r->getParentOp() &&
-         "expected that both regions belong to the same op");
+/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
+/// this function returns "true" for a successor region. The first parameter is
+/// the successor region. The second parameter indicates all already visited
+/// regions.
+using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
+
+/// Traverse the region graph starting at `begin`. The traversal is interrupted
+/// if `stopCondition` evaluates to "true" for a successor region. In that case,
+/// this function returns "true". Otherwise, if the traversal was not
+/// interrupted, this function returns "false".
+static bool traverseRegionGraph(Region *begin,
+                                StopConditionFn stopConditionFn) {
   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
   SmallVector<bool> visited(op->getNumRegions(), false);
   visited[begin->getRegionNumber()] = true;
@@ -242,7 +249,7 @@ static bool isRegionReachable(Region *begin, Region *r) {
   // Process all regions in the worklist via DFS.
   while (!worklist.empty()) {
     Region *nextRegion = worklist.pop_back_val();
-    if (nextRegion == r)
+    if (stopConditionFn(nextRegion, visited))
       return true;
     if (visited[nextRegion->getRegionNumber()])
       continue;
@@ -253,6 +260,18 @@ static bool isRegionReachable(Region *begin, Region *r) {
   return false;
 }
 
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch).
+static bool isRegionReachable(Region *begin, Region *r) {
+  assert(begin->getParentOp() == r->getParentOp() &&
+         "expected that both regions belong to the same op");
+  return traverseRegionGraph(begin,
+                             [&](Region *nextRegion, ArrayRef<bool> visited) {
+                               // Interrupt traversal if `r` was reached.
+                               return nextRegion == r;
+                             });
+}
+
 /// Return `true` if `a` and `b` are in mutually exclusive regions.
 ///
 /// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -306,6 +325,21 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
   return isRegionReachable(region, region);
 }
 
+bool RegionBranchOpInterface::hasLoop() {
+  SmallVector<RegionSuccessor> entryRegions;
+  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
+  for (RegionSuccessor successor : entryRegions)
+    if (!successor.isParent() &&
+        traverseRegionGraph(successor.getSuccessor(),
+                            [](Region *nextRegion, ArrayRef<bool> visited) {
+                              // Interrupt traversal if the region was already
+                              // visited.
+                              return visited[nextRegion->getRegionNumber()];
+                            }))
+      return true;
+  return false;
+}
+
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
   while (Region *region = op->getParentRegion()) {
     op = region->getParentOp();


        


More information about the Mlir-commits mailing list