[Mlir-commits] [mlir] b884f4e - [mlir][IR] Add ForwardDominanceIterator for IR walkers
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 13 02:03:52 PDT 2023
Author: Matthias Springer
Date: 2023-03-13T09:58:34+01:00
New Revision: b884f4ef0a2de3d0f24111411dff663fd68c2eb0
URL: https://github.com/llvm/llvm-project/commit/b884f4ef0a2de3d0f24111411dff663fd68c2eb0
DIFF: https://github.com/llvm/llvm-project/commit/b884f4ef0a2de3d0f24111411dff663fd68c2eb0.diff
LOG: [mlir][IR] Add ForwardDominanceIterator for IR walkers
This iterator is similar to `ForwardIterator` but enumerates blocks according to their successor relationship. As a first use case, this new iterator is utilized in the dialect conversion framework.
Differential Revision: https://reviews.llvm.org/D144888
Added:
mlir/include/mlir/IR/Iterators.h
Modified:
mlir/include/mlir/IR/RegionKindInterface.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/IR/RegionKindInterface.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/IR/visitors.mlir
mlir/test/lib/IR/TestVisitors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Iterators.h b/mlir/include/mlir/IR/Iterators.h
new file mode 100644
index 0000000000000..c16f7117f3dc9
--- /dev/null
+++ b/mlir/include/mlir/IR/Iterators.h
@@ -0,0 +1,75 @@
+//===- Iterators.h - IR iterators for IR visitors ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The iterators defined in this file can be used together with IR visitors.
+// Note: These iterators cannot be defined in Visitors.h because that would
+// introduce a cyclic header dependency due to Operation.h.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ITERATORS_H
+#define MLIR_IR_ITERATORS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/DepthFirstIterator.h"
+
+namespace mlir {
+/// This iterator enumerates elements in "reverse" order. It is a wrapper around
+/// llvm::reverse.
+struct ReverseIterator {
+  // llvm::reverse uses RangeT::rbegin and RangeT::rend.
+  template <typename RangeT>
+  static constexpr auto makeIterable(RangeT &&range) {
+    return llvm::reverse(
+        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
+  }
+};
+
+/// This iterator enumerates elements according to their dominance relationship.
+/// Operations and regions are enumerated in "forward" order. Blocks are
+/// enumerated according to their successor relationship. Unreachable blocks are
+/// not enumerated.
+///
+/// Note: If `NoGraphRegions` is set to "true", this iterator asserts that each
+/// visited region has SSA dominance. In either case, the ops in such regions
+/// are visited in forward order, but for regions without SSA dominance this
+/// does not guarantee that defining ops are visited before their users.
+template <bool NoGraphRegions = false>
+struct ForwardDominanceIterator {
+  static Block &makeIterable(Block &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+
+  static auto makeIterable(Region &region) {
+    if (NoGraphRegions) {
+      // Only regions with SSA dominance are allowed.
+      assert(mayHaveSSADominance(region) && "graph regions are not allowed");
+    }
+
+    // Create DFS iterator. Blocks are enumerated according to their successor
+    // relationship.
+    Block *null = nullptr;
+    auto it = region.empty()
+                  ? llvm::make_range(llvm::df_end(null), llvm::df_end(null))
+                  : llvm::depth_first(&region.front());
+
+    // Walk API expects Block references instead of pointers.
+    return llvm::make_pointee_range(it);
+  }
+
+  static MutableArrayRef<Region> makeIterable(Operation &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+};
+} // namespace mlir
+
+#endif // MLIR_IR_ITERATORS_H
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index a4b77d65cd654..46bfe717533a8 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -38,6 +38,11 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
};
} // namespace OpTrait
+/// Return "true" if the given region may have SSA dominance. This function also
+/// returns "true" in case the owner op is an unregistered op or an op that does
+/// not implement the RegionKindInterface.
+bool mayHaveSSADominance(Region &region);
+
} // namespace mlir
#include "mlir/IR/RegionKindInterface.h.inc"
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 871c19ff5c9b0..fe987741aa907 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -74,17 +74,6 @@ struct ForwardIterator {
}
};
-/// This iterator enumerates elements in "reverse" order. It is a wrapper around
-/// llvm::reverse.
-struct ReverseIterator {
-  template <typename RangeT>
-  static constexpr auto makeIterable(RangeT &&range) {
-    // llvm::reverse uses RangeT::rbegin and RangeT::rend.
-    return llvm::reverse(
-        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
-  }
-};
-
/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index 9950b9117d265..cbef3025a5dd6 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -16,3 +16,11 @@
using namespace mlir;
#include "mlir/IR/RegionKindInterface.cpp.inc"
+
+bool mlir::mayHaveSSADominance(Region &region) {
+  auto regionKindOp =
+      dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
+  if (!regionKindOp)
+    return true;
+  return regionKindOp.hasSSADominance(region.getRegionNumber());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 99f4bf6ba092f..f5e9e71506e38 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
@@ -27,55 +28,6 @@ using namespace mlir::detail;
#define DEBUG_TYPE "dialect-conversion"
-/// Recursively collect all of the operations to convert from within 'region'.
-/// If 'target' is nonnull, operations that are recursively legal have their
-/// regions pre-filtered to avoid considering them for legalization.
-static LogicalResult
-computeConversionSet(iterator_range<Region::iterator> region,
-                     Location regionLoc,
-                     SmallVectorImpl<Operation *> &toConvert,
-                     ConversionTarget *target = nullptr) {
-  if (region.empty())
-    return success();
-
-  // Traverse starting from the entry block.
-  SmallVector<Block *, 16> worklist(1, &*region.begin());
-  DenseSet<Block *> visitedBlocks;
-  visitedBlocks.insert(worklist.front());
-  while (!worklist.empty()) {
-    Block *block = worklist.pop_back_val();
-
-    // Compute the conversion set of each of the nested operations.
-    for (Operation &op : *block) {
-      toConvert.emplace_back(&op);
-
-      // Don't check this operation's children for conversion if the operation
-      // is recursively legal.
-      auto legalityInfo =
-          target ? target->isLegal(&op)
-                 : std::optional<ConversionTarget::LegalOpDetails>();
-      if (legalityInfo && legalityInfo->isRecursivelyLegal)
-        continue;
-      for (auto &region : op.getRegions()) {
-        if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
-                                        toConvert, target)))
-          return failure();
-      }
-    }
-
-    // Recurse to children that haven't been visited.
-    for (Block *succ : block->getSuccessors())
-      if (visitedBlocks.insert(succ).second)
-        worklist.push_back(succ);
-  }
-
-  // Check that all blocks in the region were visited.
-  if (llvm::any_of(llvm::drop_begin(region, 1),
-                   [&](Block &block) { return !visitedBlocks.count(&block); }))
-    return emitError(regionLoc, "unreachable blocks were not converted");
-  return success();
-}
-
/// A utility function to log a successful result for the given reason.
template <typename... Args>
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
@@ -957,10 +909,6 @@ struct ConversionPatternRewriterImpl {
void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent,
Region::iterator before);
-  /// Notifies that the blocks of a region were cloned into another.
-  void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
-                                   Location origRegionLoc);
-
/// Notifies that a pattern match failed for the given reason.
LogicalResult
notifyMatchFailure(Location loc,
@@ -1467,20 +1415,6 @@ void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
blockActions.push_back(BlockAction::getMove(laterBlock, {®ion, nullptr}));
}
-void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
-    iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
-  for (Block &block : blocks)
-    blockActions.push_back(BlockAction::getCreate(&block));
-
-  // Compute the conversion set for the inlined region.
-  auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
-
-  // This original region has already had its conversion set computed, so there
-  // shouldn't be any new failures.
-  (void)result;
-  assert(succeeded(result) && "expected region to have no unreachable blocks");
-}
-
LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
@@ -1640,12 +1574,15 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
IRMapping &mapping) {
if (region.empty())
return;
+
PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
- // Collect the range of the cloned blocks.
- auto clonedBeginIt = mapping.lookup(®ion.front())->getIterator();
- auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
- impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
+ for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
+ Block *cloned = mapping.lookup(&b);
+ impl->notifyCreatedBlock(cloned);
+ cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
+ [&](Operation *op) { notifyOperationInserted(op); });
+ }
}
void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
@@ -2454,11 +2391,16 @@ LogicalResult OperationConverter::convertOperations(
// Compute the set of operations and blocks to convert.
SmallVector<Operation *> toConvert;
for (auto *op : ops) {
- toConvert.emplace_back(op);
- for (auto ®ion : op->getRegions())
- if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
- toConvert, &target)))
- return failure();
+ op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
+ [&](Operation *op) {
+ toConvert.push_back(op);
+ // Don't check this operation's children for conversion if the
+ // operation is recursively legal.
+ auto legalityInfo = target.isLegal(op);
+ if (legalityInfo && legalityInfo->isRecursivelyLegal)
+ return WalkResult::skip();
+ return WalkResult::advance();
+ });
}
// Convert each operation and discard rewrites on failure.
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 0c9fabeee5377..ddbc334fa4eed 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -250,3 +250,85 @@ func.func @unstructured_cfg() {
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'func.func'
// CHECK: Erasing block ^bb0 from region 0 from operation 'builtin.module'
+
+// -----
+
+func.func @unordered_cfg_with_loop() {
+  "regionOp0"() ({
+  ^bb0:
+    %c = "op0"() : () -> (i1)
+    cf.cond_br %c, ^bb2, ^bb3
+  ^bb1:
+    "op1"(%val) : (i32) -> ()
+    cf.br ^bb5
+  ^bb2:
+    %val = "op2"() : () -> (i32)
+    cf.br ^bb1
+  ^bb3:
+    "op3"() : () -> ()
+    cf.br ^bb2
+  ^bb4:
+    "op4"() : () -> ()
+    cf.br ^bb2
+  ^bb5:
+    "op5"() : () -> ()
+    cf.br ^bb7
+  ^bb6:
+    "op6"() : () -> ()
+    cf.br ^bb6
+  ^bb7:
+    "op7"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+//      4
+//      |
+//      v
+// 0 -> 2 --> 1 --> 5 --> 7
+// |    ^
+// |    |              6 --
+// |   /               ^   \
+// |  /                 \  /
+// v /                   --
+// 3
+
+// CHECK-LABEL: Op forward dominance post-order visits
+// CHECK: Visiting op 'op0'
+// CHECK: Visiting op 'cf.cond_br'
+// CHECK: Visiting op 'op2'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op1'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op5'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op7'
+// CHECK: Visiting op 'op3'
+// CHECK: Visiting op 'cf.br'
+// CHECK-NOT: Visiting op 'op6'
+// CHECK: Visiting op 'regionOp0'
+// CHECK: Visiting op 'func.return'
+// CHECK: Visiting op 'func.func'
+
+// CHECK-LABEL: Block forward dominance post-order visits
+// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb5 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb7 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb3 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func.func'
+
+// CHECK-LABEL: Region forward dominance post-order visits
+// CHECK: Visiting region 0 from operation 'regionOp0'
+// CHECK: Visiting region 0 from operation 'func.func'
+
+// CHECK-LABEL: Block pre-order erasures (skip)
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Cannot erase block ^bb0 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Cannot erase block ^bb1 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Cannot erase block ^bb2 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Cannot erase block ^bb3 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Cannot erase block ^bb4 from region 0 from operation 'regionOp0', still has uses
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index 3211d10bab9fa..6ed4abc71b7db 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Iterators.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
@@ -74,6 +76,23 @@ static void testPureCallbacks(Operation *op) {
llvm::outs() << "Region reverse post-order visits"
<< "\n";
op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
+
+ // This test case tests "NoGraphRegions = true", so start the walk with
+ // functions.
+ op->walk([&](FunctionOpInterface funcOp) {
+ llvm::outs() << "Op forward dominance post-order visits"
+ << "\n";
+ funcOp->walk<WalkOrder::PostOrder,
+ ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure);
+ llvm::outs() << "Block forward dominance post-order visits"
+ << "\n";
+ funcOp->walk<WalkOrder::PostOrder,
+ ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
+ llvm::outs() << "Region forward dominance post-order visits"
+ << "\n";
+ funcOp->walk<WalkOrder::PostOrder,
+ ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
+ });
}
/// Tests erasure callbacks that skip the walk.
@@ -98,11 +117,18 @@ static void testSkipErasureCallbacks(Operation *op) {
if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp()))
return WalkResult::advance();
- llvm::outs() << "Erasing ";
- printBlock(block);
- llvm::outs() << "\n";
- block->erase();
- return WalkResult::skip();
+ if (block->use_empty()) {
+ llvm::outs() << "Erasing ";
+ printBlock(block);
+ llvm::outs() << "\n";
+ block->erase();
+ return WalkResult::skip();
+ } else {
+ llvm::outs() << "Cannot erase ";
+ printBlock(block);
+ llvm::outs() << ", still has uses\n";
+ return WalkResult::advance();
+ }
};
llvm::outs() << "Op pre-order erasures (skip)"
@@ -141,10 +167,16 @@ static void testNoSkipErasureCallbacks(Operation *op) {
op->erase();
};
auto noSkipBlockErasure = [](Block *block) {
-    llvm::outs() << "Erasing ";
-    printBlock(block);
-    llvm::outs() << "\n";
-    block->erase();
+    if (block->use_empty()) {
+      llvm::outs() << "Erasing ";
+      printBlock(block);
+      llvm::outs() << "\n";
+      block->erase();
+    } else {
+      llvm::outs() << "Cannot erase ";
+      printBlock(block);
+      llvm::outs() << ", still has uses\n";
+    }
};
llvm::outs() << "Op post-order erasures (no skip)"
More information about the Mlir-commits
mailing list