[Mlir-commits] [mlir] 0f4ba02 - [mlir][interfaces] Add helpers for detecting recursive regions

Matthias Springer llvmlistbot at llvm.org
Tue Apr 19 00:18:01 PDT 2022


Author: Matthias Springer
Date: 2022-04-19T16:13:32+09:00
New Revision: 0f4ba02db3985051adac07a87ca9da549c0eb8ad

URL: https://github.com/llvm/llvm-project/commit/0f4ba02db3985051adac07a87ca9da549c0eb8ad
DIFF: https://github.com/llvm/llvm-project/commit/0f4ba02db3985051adac07a87ca9da549c0eb8ad.diff

LOG: [mlir][interfaces] Add helpers for detecting recursive regions

Add helper functions to check if an op may be executed multiple times based on RegionBranchOpInterface.

Differential Revision: https://reviews.llvm.org/D123789

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index ff4304c03a8fa..e9be31ebcbef2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -216,6 +216,16 @@ class InvocationBounds {
 /// RegionBranchOpInterface.
 bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
 
+/// Return the innermost region enclosing `op` that may be executed
+/// repetitively according to the RegionBranchOpInterface of that region's
+/// parent op, or `nullptr` if no enclosing region is repetitive.
+Region *getEnclosingRepetitiveRegion(Operation *op);
+
+/// Return the innermost region enclosing `value`, starting the outward walk at
+/// `value.getParentRegion()`, that may be executed repetitively according to
+/// the RegionBranchOpInterface of its parent op, or `nullptr` if none exists.
+Region *getEnclosingRepetitiveRegion(Value value);
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index ac805ea8f218a..198a38ca04422 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -211,6 +211,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
        SmallVector<Attribute, 2> nullAttrs(getOperation()->getNumOperands());
        getSuccessorRegions(index, nullAttrs, regions);
     }
+
+    /// Return `true` if control flow leaving the region with the given index
+    /// may eventually re-enter that same region, possibly after passing
+    /// through other regions of this op. Branches to the parent op are ignored.
+    bool isRepetitiveRegion(unsigned index);
   }];
 }
 

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 69ed30ae7bdd5..2ed3a9f690b1b 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -309,6 +309,66 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
   return false;
 }

+bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
+  assert(index < getOperation()->getNumRegions() && "invalid region index");
+
+  // `visited[i]` is set once region `i` has been expanded. The start region
+  // is pre-marked so that it is never expanded twice; returning to it is
+  // detected explicitly when it is popped from the worklist.
+  SmallVector<bool> visited(getOperation()->getNumRegions(), false);
+  visited[index] = true;
+
+  // Retrieve all successors of the given region and enqueue them in the
+  // worklist. Branches back to the parent op leave this op and are skipped.
+  SmallVector<unsigned> worklist;
+  auto enqueueAllSuccessors = [&](unsigned regionIndex) {
+    SmallVector<RegionSuccessor> successors;
+    this->getSuccessorRegions(regionIndex, successors);
+    for (RegionSuccessor successor : successors)
+      if (!successor.isParent())
+        worklist.push_back(successor.getSuccessor()->getRegionNumber());
+  };
+  enqueueAllSuccessors(index);
+
+  // Process all regions in the worklist via DFS.
+  while (!worklist.empty()) {
+    unsigned nextRegion = worklist.pop_back_val();
+    if (nextRegion == index)
+      return true;
+    if (visited[nextRegion])
+      continue;
+    visited[nextRegion] = true;
+    enqueueAllSuccessors(nextRegion);
+  }
+
+  return false;
+}
+
+/// Walk outwards from `region` through the parent ops and return the first
+/// region that its parent op's RegionBranchOpInterface reports as
+/// repetitive, or `nullptr` if there is none. The walk stops at a region
+/// whose parent op is null instead of passing a null op to `dyn_cast`.
+static Region *findEnclosingRepetitiveRegion(Region *region) {
+  while (region) {
+    Operation *op = region->getParentOp();
+    if (!op)
+      return nullptr;
+    if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
+      if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+    region = op->getParentRegion();
+  }
+  return nullptr;
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
+  return findEnclosingRepetitiveRegion(op->getParentRegion());
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Value value) {
+  return findEnclosingRepetitiveRegion(value.getParentRegion());
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index 4f473161f3c8b..5f433219ccba5 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -42,6 +42,29 @@ struct MutuallyExclusiveRegionsOp
                            SmallVectorImpl<RegionSuccessor> &regions) {}
 };
 
+/// Each region branches to the next one, forming the cycle 0 -> 1 -> 2 -> 0.
+struct LoopRegionsOp
+    : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+  static const unsigned kNumRegions = 3;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() { return "cftest.loop_regions_op"; }
+
+  void getSuccessorRegions(Optional<unsigned> index,
+                           ArrayRef<Attribute> operands,
+                           SmallVectorImpl<RegionSuccessor> &regions) {
+    if (index) {
+      if (*index == 1)
+        // This region also branches back to the parent.
+        regions.push_back(RegionSuccessor());
+      regions.push_back(RegionSuccessor(
+          &getOperation()->getRegion((*index + 1) % kNumRegions)));
+    }
+  }
+};
+
 /// Regions are executed sequentially.
 struct SequentialRegionsOp
     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
@@ -65,7 +88,8 @@ struct SequentialRegionsOp
 struct CFTestDialect : Dialect {
   explicit CFTestDialect(MLIRContext *ctx)
       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
-    addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
+    addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
+                  SequentialRegionsOp>();
   }
   static StringRef getDialectNamespace() { return "cftest"; }
 };
@@ -142,3 +166,59 @@ TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
   EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
 }
+
+TEST(RegionBranchOpInterface, RecursiveRegions) {
+  const char *ir = R"MLIR(
+"cftest.loop_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()},  // op2
+      {"cftest.dummy_op"() : () -> ()}   // op3
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  ASSERT_TRUE(module);
+  Operation *testOp = &module->getBody()->getOperations().front();
+  auto regionOp = cast<RegionBranchOpInterface>(testOp);
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+  Operation *op3 = &testOp->getRegion(2).front().front();
+
+  // Every region of the loop op can reach itself again.
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
+  // The innermost repetitive region is the one directly containing each op.
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), &testOp->getRegion(0));
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), &testOp->getRegion(1));
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op3), &testOp->getRegion(2));
+}
+
+TEST(RegionBranchOpInterface, NotRecursiveRegions) {
+  const char *ir = R"MLIR(
+"cftest.sequential_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()}   // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  ASSERT_TRUE(module);
+  Operation *testOp = &module->getBody()->getOperations().front();
+  auto regionOp = cast<RegionBranchOpInterface>(testOp);
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+
+  EXPECT_FALSE(regionOp.isRepetitiveRegion(0));
+  EXPECT_FALSE(regionOp.isRepetitiveRegion(1));
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
+}


        


More information about the Mlir-commits mailing list