[Mlir-commits] [mlir] [mlir][IR] Add `Block::isReachable` helper function (PR #114928)
Matthias Springer
llvmlistbot at llvm.org
Wed Nov 6 22:30:17 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114928
>From a230a5254e5a29ec1efcf77ddee95c664ed3b41b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 5 Nov 2024 06:29:46 +0100
Subject: [PATCH 1/2] [mlir][IR] Add `Block::isReachable` helper function
Add a new helper function `isReachable` to `Block`. This function traverses all successors of a block to determine if another block is reachable from the current block.
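This functionality was previously reimplemented in multiple places (One-Shot Bufferize analysis and the vector transfer op optimizations); those copies now use the new helper.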
---
mlir/include/mlir/IR/Block.h | 5 ++++
.../Transforms/OneShotAnalysis.cpp | 23 ++-----------------
.../Transforms/VectorTransferOpTransforms.cpp | 15 +-----------
mlir/lib/IR/Block.cpp | 22 +++++++++++++++++-
4 files changed, 29 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 536cbf9018e898..37fa8dfe90bf0d 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -264,6 +264,11 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
succ_iterator succ_end() { return getSuccessors().end(); }
SuccessorRange getSuccessors() { return SuccessorRange(this); }
+ /// Return "true" if there is a path from this block to the given block
+ /// (according to the successors relationship). Both blocks must be in the
+ /// same region. Paths that contain a block from `except` do not count.
+ bool isReachable(Block *other, ArrayRef<Block *> except = {});
+
//===--------------------------------------------------------------------===//
// Walkers
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 829e954d53b259..d1e6acef324fbd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -273,25 +273,6 @@ static bool happensBefore(Operation *a, Operation *b,
return false;
}
-static bool isReachable(Block *from, Block *to, ArrayRef<Block *> except) {
- DenseSet<Block *> visited;
- SmallVector<Block *> worklist;
- for (Block *succ : from->getSuccessors())
- worklist.push_back(succ);
- while (!worklist.empty()) {
- Block *next = worklist.pop_back_val();
- if (llvm::is_contained(except, next))
- continue;
- if (next == to)
- return true;
- if (!visited.insert(next).second)
- continue;
- for (Block *succ : next->getSuccessors())
- worklist.push_back(succ);
- }
- return false;
-}
-
/// Return `true` if op dominance can be used to rule out a read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used to due region-based loops.
@@ -427,8 +408,8 @@ static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
Block *writeBlock = uWrite->getOwner()->getBlock();
for (Value def : definitions) {
Block *defBlock = def.getParentBlock();
- if (isReachable(readBlock, writeBlock, {defBlock}) &&
- isReachable(writeBlock, readBlock, {defBlock}))
+ if (readBlock->isReachable(writeBlock, {defBlock}) &&
+ writeBlock->isReachable(readBlock, {defBlock}))
return false;
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3a30382114c8dc..bd5f06a3b46d42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -73,20 +73,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
// Simple case where the start op dominate the destination.
if (dominators.dominates(start, dest))
return true;
- Block *startBlock = start->getBlock();
- Block *destBlock = dest->getBlock();
- SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
- startBlock->succ_end());
- SmallPtrSet<Block *, 32> visited;
- while (!worklist.empty()) {
- Block *bb = worklist.pop_back_val();
- if (!visited.insert(bb).second)
- continue;
- if (dominators.dominates(bb, destBlock))
- return true;
- worklist.append(bb->succ_begin(), bb->succ_end());
- }
- return false;
+ return start->getBlock()->isReachable(dest->getBlock());
}
/// For transfer_write to overwrite fully another transfer_write must:
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 65099f8ff15a6f..5ae98faaaba12c 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -7,9 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Block.h"
+
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -331,7 +334,7 @@ unsigned PredecessorIterator::getSuccessorIndex() const {
}
//===----------------------------------------------------------------------===//
-// SuccessorRange
+// Successors
//===----------------------------------------------------------------------===//
SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}
@@ -349,6 +352,23 @@ SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
base = term->getBlockOperands().data();
}
+bool Block::isReachable(Block *other, ArrayRef<Block *> except) {
+ assert(getParent() == other->getParent() && "expected same region");
+ SmallVector<Block *> worklist(succ_begin(), succ_end());
+ SmallPtrSet<Block *, 16> visited;
+ while (!worklist.empty()) {
+ Block *next = worklist.pop_back_val();
+ if (llvm::is_contained(except, next))
+ continue;
+ if (next == other)
+ return true;
+ if (!visited.insert(next).second)
+ continue;
+ worklist.append(next->succ_begin(), next->succ_end());
+ }
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// BlockRange
//===----------------------------------------------------------------------===//
>From ecb21fac24d9043663df8403663af9aa9778b026 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 7 Nov 2024 07:29:56 +0100
Subject: [PATCH 2/2] address comments
---
mlir/include/mlir/IR/Block.h | 8 +++++++-
mlir/lib/IR/Block.cpp | 13 ++++++++-----
2 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 37fa8dfe90bf0d..d63cae597a2325 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -16,6 +16,8 @@
#include "mlir/IR/BlockSupport.h"
#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
namespace llvm {
class BitVector;
class raw_ostream;
@@ -267,7 +269,11 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
/// Return "true" if there is a path from this block to the given block
/// (according to the successors relationship). Both blocks must be in the
/// same region. Paths that contain a block from `except` do not count.
- bool isReachable(Block *other, ArrayRef<Block *> except = {});
+ /// This function returns "false" if `other` is in `except`.
+ ///
+ /// Note: This function performs a block graph traversal and its complexity
/// is linear in the number of blocks in the parent region.
+ bool isReachable(Block *other, SmallPtrSet<Block *, 16> &&except = {});
//===--------------------------------------------------------------------===//
// Walkers
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 5ae98faaaba12c..4b1568219fb376 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -352,17 +352,20 @@ SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
base = term->getBlockOperands().data();
}
-bool Block::isReachable(Block *other, ArrayRef<Block *> except) {
+bool Block::isReachable(Block *other, SmallPtrSet<Block *, 16> &&except) {
assert(getParent() == other->getParent() && "expected same region");
+ if (except.contains(other)) {
+ // Fast path: If `other` is in the `except` set, there can be no path from
+ // "this" to `other` (that does not pass through an excluded block).
+ return false;
+ }
SmallVector<Block *> worklist(succ_begin(), succ_end());
- SmallPtrSet<Block *, 16> visited;
while (!worklist.empty()) {
Block *next = worklist.pop_back_val();
- if (llvm::is_contained(except, next))
- continue;
if (next == other)
return true;
- if (!visited.insert(next).second)
+ // Note: `except` keeps track of already visited blocks.
+ if (!except.insert(next).second)
continue;
worklist.append(next->succ_begin(), next->succ_end());
}
More information about the Mlir-commits
mailing list