[Mlir-commits] [mlir] 86bbbb3 - [mlir] Extended Dominance analysis with a function to find the nearest common dominator of two given blocks.

Marcel Koester llvmlistbot at llvm.org
Fri Mar 27 06:57:25 PDT 2020


Author: Marcel Koester
Date: 2020-03-27T14:55:40+01:00
New Revision: 86bbbb317bce06e4a8cd084c85663b10147eff65

URL: https://github.com/llvm/llvm-project/commit/86bbbb317bce06e4a8cd084c85663b10147eff65
DIFF: https://github.com/llvm/llvm-project/commit/86bbbb317bce06e4a8cd084c85663b10147eff65.diff

LOG: [mlir] Extended Dominance analysis with a function to find the nearest common dominator of two given blocks.

The Dominance analysis currently lacks a utility function to find the nearest common dominator of two given blocks. Such a function is required by a wide variety of control-flow analyses and transformations. This commit adds this function and moves the getNode function from DominanceInfo to DominanceInfoBase, as it also works for post dominators.

Differential Revision: https://reviews.llvm.org/D75507

Added: 
    mlir/test/Analysis/test-dominance.mlir
    mlir/test/lib/Transforms/TestDominance.cpp

Modified: 
    mlir/include/mlir/Analysis/Dominance.h
    mlir/lib/Analysis/Dominance.cpp
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h
index 99e9eae23189..44d8e02a6670 100644
--- a/mlir/include/mlir/Analysis/Dominance.h
+++ b/mlir/include/mlir/Analysis/Dominance.h
@@ -34,12 +34,20 @@ template <bool IsPostDom> class DominanceInfoBase {
   /// Recalculate the dominance info.
   void recalculate(Operation *op);
 
+  /// Finds the nearest common dominator block for the two given blocks a
+  /// and b. If no common dominator can be found, this function will return
+  /// nullptr.
+  Block *findNearestCommonDominator(Block *a, Block *b) const;
+
   /// Get the root dominance node of the given region.
   DominanceInfoNode *getRootNode(Region *region) {
     assert(dominanceInfos.count(region) != 0);
     return dominanceInfos[region]->getRootNode();
   }
 
+  /// Return the dominance node from the Region containing block A.
+  DominanceInfoNode *getNode(Block *a);
+
 protected:
   using super = DominanceInfoBase<IsPostDom>;
 
@@ -82,9 +90,6 @@ class DominanceInfo : public detail::DominanceInfoBase</*IsPostDom=*/false> {
     return super::properlyDominates(a, b);
   }
 
-  /// Return the dominance node from the Region containing block A.
-  DominanceInfoNode *getNode(Block *a);
-
   /// Update the internal DFS numbers for the dominance nodes.
   void updateDFSNumbers();
 };

diff  --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index b2f0930cd81a..1b5f4cd917a2 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Analysis/Dominance.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/GenericDomTreeConstruction.h"
 
 using namespace mlir;
@@ -43,6 +44,99 @@ void DominanceInfoBase<IsPostDom>::recalculate(Operation *op) {
   });
 }
 
+/// Walks up the list of containers of the given block and calls the
+/// user-defined traversal function for the block itself and then for every
+/// enclosing block found during traversal. If the user-defined function
+/// returns true for a block, traverseAncestors returns that block. Returns
+/// nullptr if the top is reached without a match.
+template <typename FuncT>
+Block *traverseAncestors(Block *block, const FuncT &func) {
+  // Invoke the user-defined traversal function on the starting block first,
+  // since it may already satisfy the predicate.
+  if (func(block))
+    return block;
+
+  Region *region = block->getParent();
+  while (region) {
+    Operation *ancestor = region->getParentOp();
+    // If we have reached the top (no parent op, or one without a block), stop.
+    if (!ancestor || !(block = ancestor->getBlock()))
+      break;
+
+    // Update the nested region using the new ancestor block.
+    region = block->getParent();
+
+    // Invoke the user-defined traversal function and check whether we can
+    // already return.
+    if (func(block))
+      return block;
+  }
+  return nullptr;
+}
+
+/// Tries to update a and b in place to (possibly enclosing) blocks that live in
+/// the same region. Returns false (and nulls b) if no common region exists.
+static bool tryGetBlocksInSameRegion(Block *&a, Block *&b) {
+  // If both blocks already live in the same region, we are done. Otherwise, we
+  // will have to check their parent operations.
+  if (a->getParent() == b->getParent())
+    return true;
+
+  // Map the region of a and of each of its ancestor blocks to that block. This
+  // allows for efficient lookups to find a commonly shared region.
+  llvm::SmallDenseMap<Region *, Block *, 4> ancestors;
+  traverseAncestors(a, [&](Block *block) {
+    ancestors[block->getParent()] = block;
+    return false;
+  });
+
+  // Try to find a common ancestor starting with block b.
+  b = traverseAncestors(
+      b, [&](Block *block) { return ancestors.count(block->getParent()) > 0; });
+
+  // If there is no match, we will not be able to find a common dominator since
+  // both regions do not share a common parent region.
+  if (!b)
+    return false;
+
+  // We have found a common parent region. Update a to the block recorded for
+  // this region.
+  auto it = ancestors.find(b->getParent());
+  assert(it != ancestors.end());
+  a = it->second;
+  return true;
+}
+
+template <bool IsPostDom>
+Block *
+DominanceInfoBase<IsPostDom>::findNearestCommonDominator(Block *a,
+                                                         Block *b) const {
+  // If either a or b is null, then conservatively return nullptr.
+  if (!a || !b)
+    return nullptr;
+
+  // Replace a and b by blocks that are in the same region, if possible.
+  if (!tryGetBlocksInSameRegion(a, b))
+    return nullptr;
+
+  // Look up the dominance info of the common region; bail out if there is none.
+  Region *parentRegion = a->getParent();
+  auto infoAIt = dominanceInfos.find(parentRegion);
+  if (infoAIt == dominanceInfos.end())
+    return nullptr;
+
+  // Since the blocks live in the same region, we can rely on already
+  // existing dominance functionality.
+  return infoAIt->second->findNearestCommonDominator(a, b);
+}
+
+template <bool IsPostDom>
+DominanceInfoNode *DominanceInfoBase<IsPostDom>::getNode(Block *a) {
+  auto *region = a->getParent();
+  assert(dominanceInfos.count(region) != 0);
+  return dominanceInfos[region]->getNode(a);
+}
+
 /// Return true if the specified block A properly dominates block B.
 template <bool IsPostDom>
 bool DominanceInfoBase<IsPostDom>::properlyDominates(Block *a, Block *b) {
@@ -57,21 +151,17 @@ bool DominanceInfoBase<IsPostDom>::properlyDominates(Block *a, Block *b) {
   // If both blocks are not in the same region, 'a' properly dominates 'b' if
   // 'b' is defined in an operation region that (recursively) ends up being
   // dominated by 'a'. Walk up the list of containers enclosing B.
-  auto *regionA = a->getParent(), *regionB = b->getParent();
-  if (regionA != regionB) {
-    Operation *bAncestor;
-    do {
-      bAncestor = regionB->getParentOp();
-      // If 'bAncestor' is the top level region, then 'a' is a block that post
-      // dominates 'b'.
-      if (!bAncestor || !bAncestor->getBlock())
-        return IsPostDom;
-
-      regionB = bAncestor->getBlock()->getParent();
-    } while (regionA != regionB);
+  auto *regionA = a->getParent();
+  if (regionA != b->getParent()) {
+    b = traverseAncestors(
+        b, [&](Block *block) { return block->getParent() == regionA; });
+
+    // If we could not find a valid block b, then a is either not a dominator
+    // (dominance case) or a post dominator (post-dominance case).
+    if (!b)
+      return IsPostDom;
 
     // Check to see if the ancestor of 'b' is the same block as 'a'.
-    b = bAncestor->getBlock();
     if (a == b)
       return true;
   }
@@ -132,12 +222,6 @@ bool DominanceInfo::properlyDominates(Value a, Operation *b) {
   return dominates(a.cast<BlockArgument>().getOwner(), b->getBlock());
 }
 
-DominanceInfoNode *DominanceInfo::getNode(Block *a) {
-  auto *region = a->getParent();
-  assert(dominanceInfos.count(region) != 0);
-  return dominanceInfos[region]->getNode(a);
-}
-
 void DominanceInfo::updateDFSNumbers() {
   for (auto &iter : dominanceInfos)
     iter.second->updateDFSNumbers();

diff  --git a/mlir/test/Analysis/test-dominance.mlir b/mlir/test/Analysis/test-dominance.mlir
new file mode 100644
index 000000000000..3e3678b40468
--- /dev/null
+++ b/mlir/test/Analysis/test-dominance.mlir
@@ -0,0 +1,207 @@
+// RUN: mlir-opt %s -test-print-dominance -split-input-file 2>&1 | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: Testing : func_condBranch
+func @func_condBranch(%cond : i1) {
+  cond_br %cond, ^bb1, ^bb2
+^bb1:
+  br ^exit
+^bb2:
+  br ^exit
+^exit:
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 0
+// CHECK-NEXT: Nearest(0, 2) = 0
+// CHECK-NEXT: Nearest(0, 3) = 0
+// CHECK: Nearest(1, 0) = 0
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 0
+// CHECK-NEXT: Nearest(1, 3) = 0
+// CHECK: Nearest(2, 0) = 0
+// CHECK-NEXT: Nearest(2, 1) = 0
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 0
+// CHECK: Nearest(3, 0) = 0
+// CHECK-NEXT: Nearest(3, 1) = 0
+// CHECK-NEXT: Nearest(3, 2) = 0
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 3
+// CHECK-NEXT: Nearest(0, 2) = 3
+// CHECK-NEXT: Nearest(0, 3) = 3
+// CHECK: Nearest(1, 0) = 3
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 3
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 3
+// CHECK-NEXT: Nearest(2, 1) = 3
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+// CHECK: Nearest(3, 0) = 3
+// CHECK-NEXT: Nearest(3, 1) = 3
+// CHECK-NEXT: Nearest(3, 2) = 3
+// CHECK-NEXT: Nearest(3, 3) = 3
+
+// -----
+
+// CHECK-LABEL: Testing : func_loop
+func @func_loop(%arg0 : i32, %arg1 : i32) {
+  br ^loopHeader(%arg0 : i32)
+^loopHeader(%counter : i32):
+  %lessThan = cmpi "slt", %counter, %arg1 : i32
+  cond_br %lessThan, ^loopBody, ^exit
+^loopBody:
+  %const0 = constant 1 : i32
+  %inc = addi %counter, %const0 : i32
+  br ^loopHeader(%inc : i32)
+^exit:
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK: Nearest(1, 0) = 0
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 1
+// CHECK-NEXT: Nearest(1, 3) = 1
+// CHECK: Nearest(2, 0) = 0
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 1
+// CHECK: Nearest(3, 0) = 0
+// CHECK-NEXT: Nearest(3, 1) = 1
+// CHECK-NEXT: Nearest(3, 2) = 1
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 1
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 1
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+// CHECK: Nearest(3, 0) = 3
+// CHECK-NEXT: Nearest(3, 1) = 3
+// CHECK-NEXT: Nearest(3, 2) = 3
+// CHECK-NEXT: Nearest(3, 3) = 3
+
+// -----
+
+// CHECK-LABEL: Testing : nested_region
+func @nested_region(%arg0 : index, %arg1 : index, %arg2 : index) {
+  loop.for %arg3 = %arg0 to %arg1 step %arg2 { }
+  return
+}
+
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 1
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 1
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+
+// -----
+
+// CHECK-LABEL: Testing : nested_region2
+func @nested_region2(%arg0 : index, %arg1 : index, %arg2 : index) {
+  loop.for %arg3 = %arg0 to %arg1 step %arg2 {
+    loop.for %arg4 = %arg0 to %arg1 step %arg2 {
+      loop.for %arg5 = %arg0 to %arg1 step %arg2 { }
+    }
+  }
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 2
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 2
+// CHECK-NEXT: Nearest(2, 1) = 2
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+// CHECK: Nearest(3, 0) = 3
+// CHECK-NEXT: Nearest(3, 1) = 3
+// CHECK-NEXT: Nearest(3, 2) = 3
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 1
+// CHECK-NEXT: Nearest(0, 2) = 2
+// CHECK-NEXT: Nearest(0, 3) = 3
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 2
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 2
+// CHECK-NEXT: Nearest(2, 1) = 2
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+
+// -----
+
+// CHECK-LABEL: Testing : func_loop_nested_region
+func @func_loop_nested_region(
+  %arg0 : i32,
+  %arg1 : i32,
+  %arg2 : index,
+  %arg3 : index,
+  %arg4 : index) {
+  br ^loopHeader(%arg0 : i32)
+^loopHeader(%counter : i32):
+  %lessThan = cmpi "slt", %counter, %arg1 : i32
+  cond_br %lessThan, ^loopBody, ^exit
+^loopBody:
+  %const0 = constant 1 : i32
+  %inc = addi %counter, %const0 : i32
+  loop.for %arg5 = %arg2 to %arg3 step %arg4 {
+    loop.for %arg6 = %arg2 to %arg3 step %arg4 { }
+  }
+  br ^loopHeader(%inc : i32)
+^exit:
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK: Nearest(2, 0) = 0
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 2
+// CHECK-NEXT: Nearest(2, 4) = 2
+// CHECK-NEXT: Nearest(2, 5) = 1
+// CHECK: Nearest(3, 0) = 0
+// CHECK-NEXT: Nearest(3, 1) = 1
+// CHECK-NEXT: Nearest(3, 2) = 2
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-NEXT: Nearest(3, 4) = 4
+// CHECK-NEXT: Nearest(3, 5) = 1
+// CHECK: Nearest(4, 0) = 0
+// CHECK-NEXT: Nearest(4, 1) = 1
+// CHECK-NEXT: Nearest(4, 2) = 2
+// CHECK-NEXT: Nearest(4, 3) = 4
+// CHECK-NEXT: Nearest(4, 4) = 4
+// CHECK-NEXT: Nearest(4, 5) = 1
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK: Nearest(2, 0) = 1
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 2
+// CHECK-NEXT: Nearest(2, 4) = 2
+// CHECK-NEXT: Nearest(2, 5) = 5
+// CHECK: Nearest(3, 0) = 1
+// CHECK-NEXT: Nearest(3, 1) = 1
+// CHECK-NEXT: Nearest(3, 2) = 2
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-NEXT: Nearest(3, 4) = 4
+// CHECK-NEXT: Nearest(3, 5) = 5
+// CHECK: Nearest(4, 0) = 1
+// CHECK-NEXT: Nearest(4, 1) = 1
+// CHECK-NEXT: Nearest(4, 2) = 2
+// CHECK-NEXT: Nearest(4, 3) = 4
+// CHECK-NEXT: Nearest(4, 4) = 4
+// CHECK-NEXT: Nearest(4, 5) = 5
\ No newline at end of file

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index b4726439c83f..b02848eed4cc 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_llvm_library(MLIRTestTransforms
   TestCallGraph.cpp
   TestConstantFold.cpp
   TestConvertGPUKernelToCubin.cpp
+  TestDominance.cpp
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp

diff  --git a/mlir/test/lib/Transforms/TestDominance.cpp b/mlir/test/lib/Transforms/TestDominance.cpp
new file mode 100644
index 000000000000..784bb1f40564
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDominance.cpp
@@ -0,0 +1,90 @@
+//===- TestDominance.cpp - Test dominance construction and information
+//-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for constructing and resolving dominance
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper class to print dominance information.
+class DominanceTest {
+public:
+  /// Constructs a new test instance using the given operation.
+  DominanceTest(Operation *operation) : operation(operation) {
+    // Number each block in the order the walk first reaches one of its ops.
+    operation->walk([&](Operation *nested) {
+      if (blockIds.count(nested->getBlock()) > 0)
+        return;
+      blockIds.insert({nested->getBlock(), blockIds.size()});
+    });
+  }
+
+  /// Prints the nearest common dominator for every ordered pair of blocks.
+  template <typename DominanceT>
+  void printDominance(DominanceT &dominanceInfo) {
+    DenseSet<Block *> parentVisited;
+    operation->walk([&](Operation *op) {
+      Block *block = op->getBlock();
+      if (!parentVisited.insert(block).second)
+        return;
+
+      DenseSet<Block *> visited;
+      operation->walk([&](Operation *nested) {
+        Block *nestedBlock = nested->getBlock();
+        if (!visited.insert(nestedBlock).second)
+          return;
+        llvm::errs() << "Nearest(" << blockIds[block] << ", "
+                     << blockIds[nestedBlock] << ") = ";
+        Block *dom =
+            dominanceInfo.findNearestCommonDominator(block, nestedBlock);
+        if (dom)
+          llvm::errs() << blockIds[dom];
+        else
+          llvm::errs() << "<no dom>";
+        llvm::errs() << "\n";
+      });
+    });
+  }
+
+private:
+  Operation *operation;
+  DenseMap<Block *, size_t> blockIds;
+};
+
+struct TestDominancePass : public FunctionPass<TestDominancePass> {
+
+  void runOnFunction() override {
+    llvm::errs() << "Testing : " << getFunction().getName() << "\n";
+    DominanceTest dominanceTest(getFunction());
+
+    // Print dominance information.
+    llvm::errs() << "--- DominanceInfo ---\n";
+    dominanceTest.printDominance(getAnalysis<DominanceInfo>());
+
+    llvm::errs() << "--- PostDominanceInfo ---\n";
+    dominanceTest.printDominance(getAnalysis<PostDominanceInfo>());
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestDominancePass() {
+  PassRegistration<TestDominancePass>(
+      "test-print-dominance",
+      "Print the dominance information for multiple regions.");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ff0f49f987b6..9234b504e2e8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -42,6 +42,7 @@ void registerTestAllReduceLoweringPass();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
 void registerTestConvertGPUKernelToCubinPass();
+void registerTestDominancePass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLinalgTransforms();
@@ -100,6 +101,7 @@ void registerTestPasses() {
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   registerTestConvertGPUKernelToCubinPass();
 #endif
+  registerTestDominancePass();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
   registerTestLinalgTransforms();


        


More information about the Mlir-commits mailing list