[Mlir-commits] [mlir] 0f4ba02 - [mlir][interfaces] Add helpers for detecting recursive regions
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 19 00:18:01 PDT 2022
Author: Matthias Springer
Date: 2022-04-19T16:13:32+09:00
New Revision: 0f4ba02db3985051adac07a87ca9da549c0eb8ad
URL: https://github.com/llvm/llvm-project/commit/0f4ba02db3985051adac07a87ca9da549c0eb8ad
DIFF: https://github.com/llvm/llvm-project/commit/0f4ba02db3985051adac07a87ca9da549c0eb8ad.diff
LOG: [mlir][interfaces] Add helpers for detecting recursive regions
Add helper functions to check if an op may be executed multiple times based on RegionBranchOpInterface.
Differential Revision: https://reviews.llvm.org/D123789
Added:
Modified:
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index ff4304c03a8fa..e9be31ebcbef2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -216,6 +216,16 @@ class InvocationBounds {
/// RegionBranchOpInterface.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
+/// Return the innermost region enclosing the given op that may be executed
+/// repetitively as per RegionBranchOpInterface, or `nullptr` if no such region
+/// exists. All ancestor regions are visited from the innermost outward; a
+/// region qualifies when its parent op implements RegionBranchOpInterface and
+/// reports `isRepetitiveRegion` for that region's number.
+Region *getEnclosingRepetitiveRegion(Operation *op);
+
+/// Return the innermost region enclosing the given Value that may be executed
+/// repetitively as per RegionBranchOpInterface, or `nullptr` if no such region
+/// exists. The walk starts at `value.getParentRegion()` and proceeds outward
+/// with the same qualification rule as the Operation overload.
+Region *getEnclosingRepetitiveRegion(Value value);
+
//===----------------------------------------------------------------------===//
// RegionBranchTerminatorOpInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index ac805ea8f218a..198a38ca04422 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -211,6 +211,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
SmallVector<Attribute, 2> nullAttrs(getOperation()->getNumOperands());
getSuccessorRegions(index, nullAttrs, regions);
}
+
+ /// Return `true` if control flow originating from the region with the given
+ /// index may eventually branch back to that same region, possibly after
+ /// passing through other regions of this op. Region successors that branch
+ /// back to the parent op are not followed, so re-entry through the parent op
+ /// is not taken into account.
+ bool isRepetitiveRegion(unsigned index);
}];
}
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 69ed30ae7bdd5..2ed3a9f690b1b 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -309,6 +309,57 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
return false;
}
+bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
+  // Depth-first search over the region-successor graph, starting from the
+  // successors of region `index`. Successors that branch to the parent op
+  // leave this op and are ignored. The start region is pre-marked as seen so
+  // that arriving at it again is reported below instead of being skipped.
+  SmallVector<bool> seen(getOperation()->getNumRegions(), false);
+  seen[index] = true;
+
+  SmallVector<unsigned> pending;
+  auto pushSuccessorsOf = [&](unsigned regionIdx) {
+    SmallVector<RegionSuccessor> succs;
+    this->getSuccessorRegions(regionIdx, succs);
+    for (RegionSuccessor &succ : succs) {
+      if (succ.isParent())
+        continue;
+      pending.push_back(succ.getSuccessor()->getRegionNumber());
+    }
+  };
+
+  pushSuccessorsOf(index);
+  while (!pending.empty()) {
+    unsigned current = pending.pop_back_val();
+    // Control flow made it back to the starting region.
+    if (current == index)
+      return true;
+    if (seen[current])
+      continue;
+    seen[current] = true;
+    pushSuccessorsOf(current);
+  }
+  return false;
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
+  // Climb the ancestor regions of `op`, innermost first, returning the first
+  // one whose parent op is a RegionBranchOpInterface that reports it as
+  // repetitive.
+  for (Region *region = op->getParentRegion(); region != nullptr;
+       region = op->getParentRegion()) {
+    op = region->getParentOp();
+    auto branchOp = dyn_cast<RegionBranchOpInterface>(op);
+    if (branchOp && branchOp.isRepetitiveRegion(region->getRegionNumber()))
+      return region;
+  }
+  return nullptr;
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Value value) {
+  // Starting from the region in which `value` lives, climb outward and return
+  // the first region that its parent op (a RegionBranchOpInterface) reports
+  // as repetitive.
+  for (Region *region = value.getParentRegion(); region;) {
+    Operation *parentOp = region->getParentOp();
+    if (auto branchOp = dyn_cast<RegionBranchOpInterface>(parentOp)) {
+      if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+    }
+    region = parentOp->getParentRegion();
+  }
+  return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// RegionBranchTerminatorOpInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index 4f473161f3c8b..5f433219ccba5 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -42,6 +42,29 @@ struct MutuallyExclusiveRegionsOp
SmallVectorImpl<RegionSuccessor> ®ions) {}
};
+/// All regions of this op call each other in a large circle: region `i`
+/// branches to region `(i + 1) % kNumRegions`. Region 1 additionally branches
+/// back to the parent op. No successors are reported when entering the op
+/// from the parent (`index` is empty).
+struct LoopRegionsOp
+    : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+  static const unsigned kNumRegions = 3;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() { return "cftest.loop_regions_op"; }
+
+  void getSuccessorRegions(Optional<unsigned> index,
+                           ArrayRef<Attribute> operands,
+                           SmallVectorImpl<RegionSuccessor> &regions) {
+    if (!index)
+      return;
+    if (*index == 1)
+      // This region also branches back to the parent.
+      regions.push_back(RegionSuccessor());
+    // Branch to the *next* region, wrapping around to close the circle.
+    // (`*index % kNumRegions` would be `*index` itself, turning every region
+    // into a self-loop and leaving the multi-hop search path untested.)
+    unsigned next = (*index + 1) % kNumRegions;
+    regions.push_back(RegionSuccessor(&getOperation()->getRegion(next)));
+  }
+};
+
/// Regions are executed sequentially.
struct SequentialRegionsOp
: public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
@@ -65,7 +88,8 @@ struct SequentialRegionsOp
struct CFTestDialect : Dialect {
explicit CFTestDialect(MLIRContext *ctx)
: Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
- addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
+ addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
+ SequentialRegionsOp>();
}
static StringRef getDialectNamespace() { return "cftest"; }
};
@@ -142,3 +166,52 @@ TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
}
+
+TEST(RegionBranchOpInterface, RecursiveRegions) {
+  const char *ir = R"MLIR(
+"cftest.loop_regions_op"() (
+ {"cftest.dummy_op"() : () -> ()}, // op1
+ {"cftest.dummy_op"() : () -> ()}, // op2
+ {"cftest.dummy_op"() : () -> ()} // op3
+ ) : () -> ()
+ )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  // Bail out instead of dereferencing a null module if parsing failed.
+  ASSERT_TRUE(module);
+  Operation *testOp = &module->getBody()->getOperations().front();
+  auto regionOp = cast<RegionBranchOpInterface>(testOp);
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+  Operation *op3 = &testOp->getRegion(2).front().front();
+
+  // Every region lies on a control-flow cycle of LoopRegionsOp.
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
+  // The innermost repetitive region of each op is the region containing it.
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), &testOp->getRegion(0));
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), &testOp->getRegion(1));
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op3), &testOp->getRegion(2));
+}
+
+TEST(RegionBranchOpInterface, NotRecursiveRegions) {
+  const char *ir = R"MLIR(
+"cftest.sequential_regions_op"() (
+ {"cftest.dummy_op"() : () -> ()}, // op1
+ {"cftest.dummy_op"() : () -> ()} // op2
+ ) : () -> ()
+ )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  // Bail out instead of dereferencing a null module if parsing failed.
+  ASSERT_TRUE(module);
+  Operation *testOp = &module->getBody()->getOperations().front();
+  auto regionOp = cast<RegionBranchOpInterface>(testOp);
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+
+  // SequentialRegionsOp runs its regions in order, so neither is re-entered.
+  EXPECT_FALSE(regionOp.isRepetitiveRegion(0));
+  EXPECT_FALSE(regionOp.isRepetitiveRegion(1));
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
+}
More information about the Mlir-commits
mailing list