[llvm-branch-commits] [mlir] [mlir][IR] Add `InsertPoint::after(ValueRange)` (PR #114940)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Nov 11 20:44:53 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114940

>From 8a49434df62e394cd109f0189349b4d28dafa525 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 9 Nov 2024 12:29:16 +0100
Subject: [PATCH] [mlir][IR] Add `OpBuilder::InsertPoint::after(ArrayRef<Value>)`

---
 mlir/include/mlir/IR/Builders.h  | 14 +++++++++++
 mlir/include/mlir/IR/Dominance.h | 23 ++++++++++++++++++
 mlir/lib/IR/Builders.cpp         | 41 ++++++++++++++++++++++++++++++++
 3 files changed, 78 insertions(+)

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 6fb71ccefda151..7ef03b87179523 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -16,6 +16,7 @@
 namespace mlir {
 
 class AffineExpr;
+class PostDominanceInfo;
 class IRMapping;
 class UnknownLoc;
 class FileLineColLoc;
@@ -341,6 +342,19 @@ class OpBuilder : public Builder {
     InsertPoint(Block *insertBlock, Block::iterator insertPt)
         : block(insertBlock), point(insertPt) {}
 
+    /// Compute an insertion point that post-dominates the definitions of all
+    /// given values, as determined by `domInfo`. `values` must be non-empty.
+    ///
+    /// Only the point right after each value's definition (the start of the
+    /// owning block for a block argument) is considered as a candidate. An
+    /// "empty" insertion point is returned if no candidate qualifies.
+    ///
+    /// Note: Some of the given values may already have gone out of scope at the
+    /// selected insertion point. (E.g., because they are defined in a nested
+    /// region or because they are not visible in an IsolatedFromAbove region.)
+    static InsertPoint after(ArrayRef<Value> values,
+                             const PostDominanceInfo &domInfo);
+
     /// Returns true if this insert point is set.
     bool isSet() const { return (block != nullptr); }
 
diff --git a/mlir/include/mlir/IR/Dominance.h b/mlir/include/mlir/IR/Dominance.h
index 63504cad211a4d..be2dcec380b6cc 100644
--- a/mlir/include/mlir/IR/Dominance.h
+++ b/mlir/include/mlir/IR/Dominance.h
@@ -187,6 +187,17 @@ class DominanceInfo : public detail::DominanceInfoBase</*IsPostDom=*/false> {
   /// dominance" of ops, the single block is considered to properly dominate
   /// itself in a graph region.
   bool properlyDominates(Block *a, Block *b) const;
+
+  bool properlyDominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+                         Block::iterator bIt, bool enclosingOk = true) const {
+    return super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
+  }
+
+  bool dominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+                 Block::iterator bIt, bool enclosingOk = true) const {
+    return (aBlock == bBlock && aIt == bIt) ||
+           super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
+  }
 };

 /// A class for computing basic postdominance information.
@@ -210,6 +221,18 @@ class PostDominanceInfo : public detail::DominanceInfoBase</*IsPostDom=*/true> {
   bool postDominates(Block *a, Block *b) const {
     return a == b || properlyPostDominates(a, b);
   }
+
+  bool properlyPostDominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+                             Block::iterator bIt,
+                             bool enclosingOk = true) const {
+    return super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
+  }
+
+  bool postDominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+                     Block::iterator bIt, bool enclosingOk = true) const {
+    return (aBlock == bBlock && aIt == bIt) ||
+           super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
+  }
 };

 } // namespace mlir
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 5397fbabc5c95e..4714c3cace6c78 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
@@ -641,3 +642,43 @@ void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
 void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
+
+OpBuilder::InsertPoint
+OpBuilder::InsertPoint::after(ArrayRef<Value> values,
+                              const PostDominanceInfo &domInfo) {
+  // Helper function that computes the point after v's definition.
+  auto computeAfterIp = [](Value v) -> std::pair<Block *, Block::iterator> {
+    if (auto blockArg = dyn_cast<BlockArgument>(v))
+      return std::make_pair(blockArg.getOwner(), blockArg.getOwner()->begin());
+    Operation *op = v.getDefiningOp();
+    return std::make_pair(op->getBlock(), ++op->getIterator());
+  };
+
+  // Step 1: Select the "latest" candidate w.r.t. post-dominance. Do not bail
+  // out on incomparable candidates: a later one may post-dominate both.
+  assert(!values.empty() && "expected at least one Value");
+  auto [block, blockIt] = computeAfterIp(values.front());
+  for (Value v : values.drop_front()) {
+    auto [candidateBlock, candidateBlockIt] = computeAfterIp(v);
+    if (domInfo.postDominates(candidateBlock, candidateBlockIt, block,
+                              blockIt)) {
+      // The point after v's definition post-dominates the current (and all
+      // previous) insertion points. Note: Post-dominance is transitive.
+      block = candidateBlock;
+      blockIt = candidateBlockIt;
+    }
+  }
+
+  // Step 2: Verify that the selected point post-dominates the points after
+  // all definitions. If it does not, no candidate post-dominates all others
+  // (a maximum would have been selected in step 1), so there is no suitable
+  // insertion point right after one of the given values.
+  for (Value v : values) {
+    auto [candidateBlock, candidateBlockIt] = computeAfterIp(v);
+    if (!domInfo.postDominates(block, blockIt, candidateBlock,
+                               candidateBlockIt))
+      return InsertPoint(nullptr, Block::iterator());
+  }
+
+  return InsertPoint(block, blockIt);
+}



More information about the llvm-branch-commits mailing list