[Mlir-commits] [mlir] b884f4e - [mlir][IR] Add ForwardDominanceIterator for IR walkers

Matthias Springer llvmlistbot at llvm.org
Mon Mar 13 02:03:52 PDT 2023


Author: Matthias Springer
Date: 2023-03-13T09:58:34+01:00
New Revision: b884f4ef0a2de3d0f24111411dff663fd68c2eb0

URL: https://github.com/llvm/llvm-project/commit/b884f4ef0a2de3d0f24111411dff663fd68c2eb0
DIFF: https://github.com/llvm/llvm-project/commit/b884f4ef0a2de3d0f24111411dff663fd68c2eb0.diff

LOG: [mlir][IR] Add ForwardDominanceIterator for IR walkers

This iterator is similar to `ForwardIterator` but enumerates blocks according to their successor relationship. As a first use case, this new iterator is utilized in the dialect conversion framework.

Differential Revision: https://reviews.llvm.org/D144888

Added: 
    mlir/include/mlir/IR/Iterators.h

Modified: 
    mlir/include/mlir/IR/RegionKindInterface.h
    mlir/include/mlir/IR/Visitors.h
    mlir/lib/IR/RegionKindInterface.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/IR/visitors.mlir
    mlir/test/lib/IR/TestVisitors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Iterators.h b/mlir/include/mlir/IR/Iterators.h
new file mode 100644
index 0000000000000..c16f7117f3dc9
--- /dev/null
+++ b/mlir/include/mlir/IR/Iterators.h
@@ -0,0 +1,75 @@
+//===- Iterators.h - IR iterators for IR visitors ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The iterators defined in this file can be used together with IR visitors.
+// Note: These iterators cannot be defined in Visitors.h because that would
+// introduce a cyclic header dependency due to Operation.h.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ITERATORS_H
+#define MLIR_IR_ITERATORS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/DepthFirstIterator.h"
+
+namespace mlir {
+/// This iterator enumerates elements in "reverse" order. It applies
+/// llvm::reverse to the iterable produced by ForwardIterator.
+struct ReverseIterator {
+  // llvm::reverse uses rbegin/rend of the range that ForwardIterator returns.
+  template <typename RangeT>
+  static constexpr auto makeIterable(RangeT &&range) {
+    return llvm::reverse(
+        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
+  }
+};
+
+/// This iterator enumerates elements according to their dominance relationship.
+/// Operations and regions are enumerated in "forward" order. Blocks are
+/// enumerated depth-first along successor edges, starting at the region's entry
+/// block. Blocks that are unreachable from the entry block are not enumerated.
+///
+/// Note: If `NoGraphRegions` is set to "true", this iterator asserts that each
+/// visited region may have SSA dominance. Regardless of the flag, the ops of a
+/// block are visited in forward order; for regions without SSA dominance this
+/// does not guarantee that defining ops are visited before their users.
+template <bool NoGraphRegions = false>
+struct ForwardDominanceIterator {
+  static Block &makeIterable(Block &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+
+  static auto makeIterable(Region &region) {
+    if (NoGraphRegions) {
+      // Reject regions that are known not to have SSA dominance.
+      assert(mayHaveSSADominance(region) && "graph regions are not allowed");
+    }
+
+    // Create DFS iterator rooted at the entry block. An empty region has no
+    // entry block, so it yields an empty range (end iterator to end iterator).
+    Block *null = nullptr;
+    auto it = region.empty()
+                  ? llvm::make_range(llvm::df_end(null), llvm::df_end(null))
+                  : llvm::depth_first(&region.front());
+
+    // Walk API expects Block references instead of pointers.
+    return llvm::make_pointee_range(it);
+  }
+
+  static MutableArrayRef<Region> makeIterable(Operation &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+};
+} // namespace mlir
+
+#endif // MLIR_IR_ITERATORS_H

diff  --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index a4b77d65cd654..46bfe717533a8 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -38,6 +38,11 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
 };
 } // namespace OpTrait
 
+/// Return "true" if the given region may have SSA dominance. This function also
+/// returns "true" if the region has no parent op, or if its parent op is
+/// unregistered or does not implement the RegionKindInterface.
+bool mayHaveSSADominance(Region &region);
+
 } // namespace mlir
 
 #include "mlir/IR/RegionKindInterface.h.inc"

diff  --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 871c19ff5c9b0..fe987741aa907 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -74,17 +74,6 @@ struct ForwardIterator {
   }
 };
 
-/// This iterator enumerates elements in "reverse" order. It is a wrapper around
-/// llvm::reverse.
-struct ReverseIterator {
-  template <typename RangeT>
-  static constexpr auto makeIterable(RangeT &&range) {
-    // llvm::reverse uses RangeT::rbegin and RangeT::rend.
-    return llvm::reverse(
-        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
-  }
-};
-
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached

diff  --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index 9950b9117d265..cbef3025a5dd6 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -16,3 +16,11 @@
 using namespace mlir;
 
 #include "mlir/IR/RegionKindInterface.cpp.inc"
+
+bool mlir::mayHaveSSADominance(Region &region) {
+  auto regionKindOp =
+      dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
+  if (!regionKindOp)
+    return true; // Region kind unknown: SSA dominance cannot be ruled out.
+  return regionKindOp.hasSSADominance(region.getRegionNumber());
+}

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 99f4bf6ba092f..f5e9e71506e38 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
@@ -27,55 +28,6 @@ using namespace mlir::detail;
 
 #define DEBUG_TYPE "dialect-conversion"
 
-/// Recursively collect all of the operations to convert from within 'region'.
-/// If 'target' is nonnull, operations that are recursively legal have their
-/// regions pre-filtered to avoid considering them for legalization.
-static LogicalResult
-computeConversionSet(iterator_range<Region::iterator> region,
-                     Location regionLoc,
-                     SmallVectorImpl<Operation *> &toConvert,
-                     ConversionTarget *target = nullptr) {
-  if (region.empty())
-    return success();
-
-  // Traverse starting from the entry block.
-  SmallVector<Block *, 16> worklist(1, &*region.begin());
-  DenseSet<Block *> visitedBlocks;
-  visitedBlocks.insert(worklist.front());
-  while (!worklist.empty()) {
-    Block *block = worklist.pop_back_val();
-
-    // Compute the conversion set of each of the nested operations.
-    for (Operation &op : *block) {
-      toConvert.emplace_back(&op);
-
-      // Don't check this operation's children for conversion if the operation
-      // is recursively legal.
-      auto legalityInfo =
-          target ? target->isLegal(&op)
-                 : std::optional<ConversionTarget::LegalOpDetails>();
-      if (legalityInfo && legalityInfo->isRecursivelyLegal)
-        continue;
-      for (auto &region : op.getRegions()) {
-        if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
-                                        toConvert, target)))
-          return failure();
-      }
-    }
-
-    // Recurse to children that haven't been visited.
-    for (Block *succ : block->getSuccessors())
-      if (visitedBlocks.insert(succ).second)
-        worklist.push_back(succ);
-  }
-
-  // Check that all blocks in the region were visited.
-  if (llvm::any_of(llvm::drop_begin(region, 1),
-                   [&](Block &block) { return !visitedBlocks.count(&block); }))
-    return emitError(regionLoc, "unreachable blocks were not converted");
-  return success();
-}
-
 /// A utility function to log a successful result for the given reason.
 template <typename... Args>
 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
@@ -957,10 +909,6 @@ struct ConversionPatternRewriterImpl {
   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
                                         Region::iterator before);
 
-  /// Notifies that the blocks of a region were cloned into another.
-  void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
-                                   Location origRegionLoc);
-
   /// Notifies that a pattern match failed for the given reason.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -1467,20 +1415,6 @@ void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
 }
 
-void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
-    iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
-  for (Block &block : blocks)
-    blockActions.push_back(BlockAction::getCreate(&block));
-
-  // Compute the conversion set for the inlined region.
-  auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
-
-  // This original region has already had its conversion set computed, so there
-  // shouldn't be any new failures.
-  (void)result;
-  assert(succeeded(result) && "expected region to have no unreachable blocks");
-}
-
 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
@@ -1640,12 +1574,15 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
                                                   IRMapping &mapping) {
   if (region.empty())
     return;
+
   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
 
-  // Collect the range of the cloned blocks.
-  auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
-  auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
-  impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
+  for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
+    Block *cloned = mapping.lookup(&b);
+    impl->notifyCreatedBlock(cloned);
+    cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
+        [&](Operation *op) { notifyOperationInserted(op); });
+  }
 }
 
 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
@@ -2454,11 +2391,16 @@ LogicalResult OperationConverter::convertOperations(
   // Compute the set of operations and blocks to convert.
   SmallVector<Operation *> toConvert;
   for (auto *op : ops) {
-    toConvert.emplace_back(op);
-    for (auto &region : op->getRegions())
-      if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
-                                      toConvert, &target)))
-        return failure();
+    op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
+        [&](Operation *op) {
+          toConvert.push_back(op);
+          // Don't check this operation's children for conversion if the
+          // operation is recursively legal.
+          auto legalityInfo = target.isLegal(op);
+          if (legalityInfo && legalityInfo->isRecursivelyLegal)
+            return WalkResult::skip();
+          return WalkResult::advance();
+        });
   }
 
   // Convert each operation and discard rewrites on failure.

diff  --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 0c9fabeee5377..ddbc334fa4eed 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -250,3 +250,85 @@ func.func @unstructured_cfg() {
 // CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
 // CHECK:       Erasing block ^bb0 from region 0 from operation 'func.func'
 // CHECK:       Erasing block ^bb0 from region 0 from operation 'builtin.module'
+
+// -----
+
+func.func @unordered_cfg_with_loop() {
+  "regionOp0"() ({
+    ^bb0:
+      %c = "op0"() : () -> (i1)
+      cf.cond_br %c, ^bb2, ^bb3
+    ^bb1:
+      "op1"(%val) : (i32) -> ()
+      cf.br ^bb5
+    ^bb2:
+      %val = "op2"() : () -> (i32)
+      cf.br ^bb1
+    ^bb3:
+      "op3"() : () -> ()
+      cf.br ^bb2
+    ^bb4:
+      "op4"() : () -> ()
+      cf.br ^bb2
+    ^bb5:
+      "op5"() : () -> ()
+      cf.br ^bb7
+    ^bb6:
+      "op6"() : () -> ()
+      cf.br ^bb6
+    ^bb7:
+      "op7"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+//      4
+//      |
+//      v
+// 0 -> 2 --> 1 --> 5 --> 7
+// |    ^
+// |    |      6 --
+// |   /       ^   \
+// |  /         \  /
+// v /           --
+//  3
+
+// CHECK-LABEL: Op forward dominance post-order visits
+// CHECK:       Visiting op 'op0'
+// CHECK:       Visiting op 'cf.cond_br'
+// CHECK:       Visiting op 'op2'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op1'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op5'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op7'
+// CHECK:       Visiting op 'op3'
+// CHECK:       Visiting op 'cf.br'
+// CHECK-NOT:   Visiting op 'op6'
+// CHECK:       Visiting op 'regionOp0'
+// CHECK:       Visiting op 'func.return'
+// CHECK:       Visiting op 'func.func'
+
+// CHECK-LABEL: Block forward dominance post-order visits
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb5 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb7 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb3 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'func.func'
+
+// CHECK-LABEL: Region forward dominance post-order visits
+// CHECK:       Visiting region 0 from operation 'regionOp0'
+// CHECK:       Visiting region 0 from operation 'func.func'
+
+// CHECK-LABEL: Block pre-order erasures (skip)
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Cannot erase block ^bb0 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Cannot erase block ^bb1 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Cannot erase block ^bb2 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Cannot erase block ^bb3 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Cannot erase block ^bb4 from region 0 from operation 'regionOp0', still has uses

diff  --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index 3211d10bab9fa..6ed4abc71b7db 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
@@ -74,6 +76,23 @@ static void testPureCallbacks(Operation *op) {
   llvm::outs() << "Region reverse post-order visits"
                << "\n";
   op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
+
+  // This test case tests "NoGraphRegions = true", so start the walk with
+  // functions.
+  op->walk([&](FunctionOpInterface funcOp) {
+    llvm::outs() << "Op forward dominance post-order visits"
+                 << "\n";
+    funcOp->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure);
+    llvm::outs() << "Block forward dominance post-order visits"
+                 << "\n";
+    funcOp->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
+    llvm::outs() << "Region forward dominance post-order visits"
+                 << "\n";
+    funcOp->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
+  });
 }
 
 /// Tests erasure callbacks that skip the walk.
@@ -98,11 +117,18 @@ static void testSkipErasureCallbacks(Operation *op) {
     if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp()))
       return WalkResult::advance();
 
-    llvm::outs() << "Erasing ";
-    printBlock(block);
-    llvm::outs() << "\n";
-    block->erase();
-    return WalkResult::skip();
+    if (block->use_empty()) {
+      llvm::outs() << "Erasing ";
+      printBlock(block);
+      llvm::outs() << "\n";
+      block->erase();
+      return WalkResult::skip();
+    } else {
+      llvm::outs() << "Cannot erase ";
+      printBlock(block);
+      llvm::outs() << ", still has uses\n";
+      return WalkResult::advance();
+    }
   };
 
   llvm::outs() << "Op pre-order erasures (skip)"
@@ -141,10 +167,16 @@ static void testNoSkipErasureCallbacks(Operation *op) {
     op->erase();
   };
   auto noSkipBlockErasure = [](Block *block) {
-    llvm::outs() << "Erasing ";
-    printBlock(block);
-    llvm::outs() << "\n";
-    block->erase();
+    if (block->use_empty()) {
+      llvm::outs() << "Erasing ";
+      printBlock(block);
+      llvm::outs() << "\n";
+      block->erase();
+    } else {
+      llvm::outs() << "Cannot erase ";
+      printBlock(block);
+      llvm::outs() << ", still has uses\n";
+    }
   };
 
   llvm::outs() << "Op post-order erasures (no skip)"


        


More information about the Mlir-commits mailing list