[Mlir-commits] [mlir] [mlir][IR] Add `Block::isReachable` helper function (PR #114928)

Matthias Springer llvmlistbot at llvm.org
Wed Nov 6 22:30:17 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114928

>From a230a5254e5a29ec1efcf77ddee95c664ed3b41b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 5 Nov 2024 06:29:46 +0100
Subject: [PATCH 1/2] [mlir][IR] Add `Block::isReachable` helper function

Add a new helper function `isReachable` to `Block`. This function transitively traverses the successors of a block to determine whether another block is reachable from the current block.

This functionality had previously been reimplemented in multiple places; those copies are replaced with calls to the new helper.
---
 mlir/include/mlir/IR/Block.h                  |  5 ++++
 .../Transforms/OneShotAnalysis.cpp            | 23 ++-----------------
 .../Transforms/VectorTransferOpTransforms.cpp | 15 +-----------
 mlir/lib/IR/Block.cpp                         | 22 +++++++++++++++++-
 4 files changed, 29 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 536cbf9018e898..37fa8dfe90bf0d 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -264,6 +264,11 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
   succ_iterator succ_end() { return getSuccessors().end(); }
   SuccessorRange getSuccessors() { return SuccessorRange(this); }
 
+  /// Return "true" if there is a path from this block to the given block
+  /// (according to the successors relationship). Both blocks must be in the
+  /// same region. Paths that contain a block from `except` do not count.
+  bool isReachable(Block *other, ArrayRef<Block *> except = {});
+
   //===--------------------------------------------------------------------===//
   // Walkers
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 829e954d53b259..d1e6acef324fbd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -273,25 +273,6 @@ static bool happensBefore(Operation *a, Operation *b,
   return false;
 }
 
-static bool isReachable(Block *from, Block *to, ArrayRef<Block *> except) {
-  DenseSet<Block *> visited;
-  SmallVector<Block *> worklist;
-  for (Block *succ : from->getSuccessors())
-    worklist.push_back(succ);
-  while (!worklist.empty()) {
-    Block *next = worklist.pop_back_val();
-    if (llvm::is_contained(except, next))
-      continue;
-    if (next == to)
-      return true;
-    if (!visited.insert(next).second)
-      continue;
-    for (Block *succ : next->getSuccessors())
-      worklist.push_back(succ);
-  }
-  return false;
-}
-
 /// Return `true` if op dominance can be used to rule out a read-after-write
 /// conflicts based on the ordering of ops. Returns `false` if op dominance
 /// cannot be used to due region-based loops.
@@ -427,8 +408,8 @@ static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
   Block *writeBlock = uWrite->getOwner()->getBlock();
   for (Value def : definitions) {
     Block *defBlock = def.getParentBlock();
-    if (isReachable(readBlock, writeBlock, {defBlock}) &&
-        isReachable(writeBlock, readBlock, {defBlock}))
+    if (readBlock->isReachable(writeBlock, {defBlock}) &&
+        writeBlock->isReachable(readBlock, {defBlock}))
       return false;
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3a30382114c8dc..bd5f06a3b46d42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -73,20 +73,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
   // Simple case where the start op dominate the destination.
   if (dominators.dominates(start, dest))
     return true;
-  Block *startBlock = start->getBlock();
-  Block *destBlock = dest->getBlock();
-  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
-                                    startBlock->succ_end());
-  SmallPtrSet<Block *, 32> visited;
-  while (!worklist.empty()) {
-    Block *bb = worklist.pop_back_val();
-    if (!visited.insert(bb).second)
-      continue;
-    if (dominators.dominates(bb, destBlock))
-      return true;
-    worklist.append(bb->succ_begin(), bb->succ_end());
-  }
-  return false;
+  return start->getBlock()->isReachable(dest->getBlock());
 }
 
 /// For transfer_write to overwrite fully another transfer_write must:
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 65099f8ff15a6f..5ae98faaaba12c 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -7,9 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Block.h"
+
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -331,7 +334,7 @@ unsigned PredecessorIterator::getSuccessorIndex() const {
 }
 
 //===----------------------------------------------------------------------===//
-// SuccessorRange
+// Successors
 //===----------------------------------------------------------------------===//
 
 SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}
@@ -349,6 +352,23 @@ SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
     base = term->getBlockOperands().data();
 }
 
+bool Block::isReachable(Block *other, ArrayRef<Block *> except) {
+  assert(getParent() == other->getParent() && "expected same region");
+  SmallVector<Block *> worklist(succ_begin(), succ_end());
+  SmallPtrSet<Block *, 16> visited;
+  while (!worklist.empty()) {
+    Block *next = worklist.pop_back_val();
+    if (llvm::is_contained(except, next))
+      continue;
+    if (next == other)
+      return true;
+    if (!visited.insert(next).second)
+      continue;
+    worklist.append(next->succ_begin(), next->succ_end());
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // BlockRange
 //===----------------------------------------------------------------------===//

>From ecb21fac24d9043663df8403663af9aa9778b026 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 7 Nov 2024 07:29:56 +0100
Subject: [PATCH 2/2] address comments

---
 mlir/include/mlir/IR/Block.h |  8 +++++++-
 mlir/lib/IR/Block.cpp        | 13 ++++++++-----
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 37fa8dfe90bf0d..d63cae597a2325 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -16,6 +16,8 @@
 #include "mlir/IR/BlockSupport.h"
 #include "mlir/IR/Visitors.h"
 
+#include "llvm/ADT/SmallPtrSet.h"
+
 namespace llvm {
 class BitVector;
 class raw_ostream;
@@ -267,7 +269,11 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
   /// Return "true" if there is a path from this block to the given block
   /// (according to the successors relationship). Both blocks must be in the
   /// same region. Paths that contain a block from `except` do not count.
-  bool isReachable(Block *other, ArrayRef<Block *> except = {});
+  /// This function returns "false" if `other` is in `except`.
+  ///
+  /// Note: This function performs a block graph traversal; its complexity is
+  /// linear in the number of blocks in the parent region.
+  bool isReachable(Block *other, SmallPtrSet<Block *, 16> &&except = {});
 
   //===--------------------------------------------------------------------===//
   // Walkers
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 5ae98faaaba12c..4b1568219fb376 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -352,17 +352,20 @@ SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
     base = term->getBlockOperands().data();
 }
 
-bool Block::isReachable(Block *other, ArrayRef<Block *> except) {
+bool Block::isReachable(Block *other, SmallPtrSet<Block *, 16> &&except) {
   assert(getParent() == other->getParent() && "expected same region");
+  if (except.contains(other)) {
+    // Fast path: If `other` is in the `except` set, there can be no path from
+    // "this" to `other` (that does not pass through an excluded block).
+    return false;
+  }
   SmallVector<Block *> worklist(succ_begin(), succ_end());
-  SmallPtrSet<Block *, 16> visited;
   while (!worklist.empty()) {
     Block *next = worklist.pop_back_val();
-    if (llvm::is_contained(except, next))
-      continue;
     if (next == other)
       return true;
-    if (!visited.insert(next).second)
+    // Note: `except` keeps track of already visited blocks.
+    if (!except.insert(next).second)
       continue;
     worklist.append(next->succ_begin(), next->succ_end());
   }



More information about the Mlir-commits mailing list