[Mlir-commits] [mlir] 9297b9f - [mlir][Transforms][NFC] Improve builder/listener API of OperationFolder

Matthias Springer llvmlistbot at llvm.org
Wed Mar 22 01:24:57 PDT 2023


Author: Matthias Springer
Date: 2023-03-22T09:24:47+01:00
New Revision: 9297b9f8eeecc5ea6571cf45985ba77bc2960427

URL: https://github.com/llvm/llvm-project/commit/9297b9f8eeecc5ea6571cf45985ba77bc2960427
DIFF: https://github.com/llvm/llvm-project/commit/9297b9f8eeecc5ea6571cf45985ba77bc2960427.diff

LOG: [mlir][Transforms][NFC] Improve builder/listener API of OperationFolder

The constructor of `OperationFolder` takes a listener. Therefore, the remaining API should not take any builders or rewriters: passing one could lead to double notifications if a listener is also attached to that builder/rewriter.

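As an internal cleanup, `OperationFolder` now owns an `IRRewriter` instead of a `RewriterBase::Listener`. In most cases, `OperationFolder` no longer has to notify or deal with listeners itself; the rewriter takes care of this. The one exception is an in-place fold: the folding API does not notify listeners, so `OperationFolder` still calls `notifyOperationModified` manually.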

Differential Revision: https://reviews.llvm.org/D146134

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/FoldUtils.h
    mlir/lib/Transforms/SCCP.cpp
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/test/lib/Transforms/TestIntRangeInference.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9c4790c03120..600ace488273 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -624,7 +624,9 @@ class RewriterBase : public OpBuilder {
 
 protected:
   /// Initialize the builder.
-  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx) {}
+  explicit RewriterBase(MLIRContext *ctx,
+                        OpBuilder::Listener *listener = nullptr)
+      : OpBuilder(ctx, listener) {}
   explicit RewriterBase(const OpBuilder &otherBuilder)
       : OpBuilder(otherBuilder) {}
   virtual ~RewriterBase();
@@ -648,7 +650,8 @@ class RewriterBase : public OpBuilder {
 /// such as a `PatternRewriter`, is not available.
 class IRRewriter : public RewriterBase {
 public:
-  explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+  explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
+      : RewriterBase(ctx, listener) {}
   explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
 };
 

diff  --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index a6dc18369e77..2600da361496 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -32,8 +32,8 @@ class Value;
 /// generated along the way.
 class OperationFolder {
 public:
-  OperationFolder(MLIRContext *ctx, RewriterBase::Listener *listener = nullptr)
-      : interfaces(ctx), listener(listener) {}
+  OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
+      : interfaces(ctx), rewriter(ctx, listener) {}
 
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
@@ -61,10 +61,11 @@ class OperationFolder {
   /// Clear out any constants cached inside of the folder.
   void clear();
 
-  /// Get or create a constant using the given builder. On success this returns
-  /// the constant operation, nullptr otherwise.
-  Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
-                            Attribute value, Type type, Location loc);
+  /// Get or create a constant for use in the specified block. The constant may
+  /// be created in a parent block. On success this returns the constant
+  /// operation, nullptr otherwise.
+  Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value,
+                            Type type, Location loc);
 
 private:
   /// This map keeps track of uniqued constants by dialect, attribute, and type.
@@ -74,29 +75,25 @@ class OperationFolder {
   using ConstantMap =
       DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
 
-  /// Erase the given operation and notify the listener.
-  void eraseOp(Operation *op);
-
   /// Returns true if the given operation is an already folded constant that is
   /// owned by this folder.
   bool isFolderOwnedConstant(Operation *op) const;
 
   /// Tries to perform folding on the given `op`. If successful, populates
   /// `results` with the results of the folding.
-  LogicalResult tryToFold(OpBuilder &builder, Operation *op,
-                          SmallVectorImpl<Value> &results);
+  LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results);
 
-  /// Try to process a set of fold results, generating constants as necessary.
-  /// Populates `results` on success, otherwise leaves it unchanged.
-  LogicalResult processFoldResults(OpBuilder &builder, Operation *op,
+  /// Try to process a set of fold results. Populates `results` on success,
+  /// otherwise leaves it unchanged.
+  LogicalResult processFoldResults(Operation *op,
                                    SmallVectorImpl<Value> &results,
                                    ArrayRef<OpFoldResult> foldResults);
 
   /// Try to get or create a new constant entry. On success this returns the
   /// constant operation, nullptr otherwise.
   Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
-                                    Dialect *dialect, OpBuilder &builder,
-                                    Attribute value, Type type, Location loc);
+                                    Dialect *dialect, Attribute value,
+                                    Type type, Location loc);
 
   /// A mapping between an insertion region and the constants that have been
   /// created within it.
@@ -109,8 +106,8 @@ class OperationFolder {
   /// A collection of dialect folder interfaces.
   DialectInterfaceCollection<DialectFoldInterface> interfaces;
 
-  /// An optional listener that is notified of all IR changes.
-  RewriterBase::Listener *listener = nullptr;
+  /// A rewriter that performs all IR modifications.
+  IRRewriter rewriter;
 };
 
 } // namespace mlir

diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index b32173a3a981..14435b37acc9 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -51,9 +51,9 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver,
 
   // Attempt to materialize a constant for the given value.
   Dialect *dialect = latticeValue.getConstantDialect();
-  Value constant = folder.getOrCreateConstant(builder, dialect,
-                                              latticeValue.getConstantValue(),
-                                              value.getType(), value.getLoc());
+  Value constant = folder.getOrCreateConstant(
+      builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(),
+      value.getType(), value.getLoc());
   if (!constant)
     return failure();
 

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 22a488efb37e..827c0ad4290b 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -84,26 +84,25 @@ LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
 
   // Try to fold the operation.
   SmallVector<Value, 8> results;
-  OpBuilder builder(op, listener);
-  if (failed(tryToFold(builder, op, results)))
+  if (failed(tryToFold(op, results)))
     return failure();
 
   // Check to see if the operation was just updated in place.
   if (results.empty()) {
     if (inPlaceUpdate)
       *inPlaceUpdate = true;
-    if (listener)
-      listener->notifyOperationModified(op);
+    if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
+            rewriter.getListener())) {
+      // Folding API does not notify listeners, so we have to notify manually.
+      rewriteListener->notifyOperationModified(op);
+    }
     return success();
   }
 
   // Constant folding succeeded. Replace all of the result values and erase the
   // operation.
-  if (listener)
-    listener->notifyOperationReplaced(op, results);
-  for (unsigned i = 0, e = results.size(); i != e; ++i)
-    op->getResult(i).replaceAllUsesWith(results[i]);
-  eraseOp(op);
+  notifyRemoval(op);
+  rewriter.replaceOp(op, results);
   return success();
 }
 
@@ -141,10 +140,8 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
 
   // If there is an existing constant, replace `op`.
   if (folderConstOp) {
-    if (listener)
-      listener->notifyOperationReplaced(op, folderConstOp->getResults());
-    op->replaceAllUsesWith(folderConstOp);
-    eraseOp(op);
+    notifyRemoval(op);
+    rewriter.replaceOp(op, folderConstOp->getResults());
     return false;
   }
 
@@ -162,13 +159,6 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
   return true;
 }
 
-void OperationFolder::eraseOp(Operation *op) {
-  notifyRemoval(op);
-  if (listener)
-    listener->notifyOperationRemoved(op);
-  op->erase();
-}
-
 /// Notifies that the given constant `op` should be remove from this
 /// OperationFolder's internal bookkeeping.
 void OperationFolder::notifyRemoval(Operation *op) {
@@ -202,22 +192,18 @@ void OperationFolder::clear() {
 
 /// Get or create a constant using the given builder. On success this returns
 /// the constant operation, nullptr otherwise.
-Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
+Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
                                            Attribute value, Type type,
                                            Location loc) {
-  OpBuilder::InsertionGuard foldGuard(builder);
-
-  // Use the builder insertion block to find an insertion point for the
-  // constant.
-  auto *insertRegion =
-      getInsertionRegion(interfaces, builder.getInsertionBlock());
+  // Find an insertion point for the constant.
+  auto *insertRegion = getInsertionRegion(interfaces, block);
   auto &entry = insertRegion->front();
-  builder.setInsertionPoint(&entry, entry.begin());
+  rewriter.setInsertionPoint(&entry, entry.begin());
 
   // Get the constant map for the insertion region of this operation.
   auto &uniquedConstants = foldScopes[insertRegion];
-  Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
-                                              builder, value, type, loc);
+  Operation *constOp =
+      tryGetOrCreateConstant(uniquedConstants, dialect, value, type, loc);
   return constOp ? constOp->getResult(0) : Value();
 }
 
@@ -227,7 +213,7 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
 
 /// Tries to perform folding on the given `op`. If successful, populates
 /// `results` with the results of the folding.
-LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op,
+LogicalResult OperationFolder::tryToFold(Operation *op,
                                          SmallVectorImpl<Value> &results) {
   SmallVector<Attribute, 8> operandConstants;
 
@@ -257,13 +243,13 @@ LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op,
   // fold.
   SmallVector<OpFoldResult, 8> foldResults;
   if (failed(op->fold(operandConstants, foldResults)) ||
-      failed(processFoldResults(builder, op, results, foldResults)))
+      failed(processFoldResults(op, results, foldResults)))
     return success(updatedOpOperands);
   return success();
 }
 
 LogicalResult
-OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
+OperationFolder::processFoldResults(Operation *op,
                                     SmallVectorImpl<Value> &results,
                                     ArrayRef<OpFoldResult> foldResults) {
   // Check to see if the operation was just updated in place.
@@ -273,11 +259,9 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
 
   // Create a builder to insert new operations into the entry block of the
   // insertion region.
-  auto *insertRegion =
-      getInsertionRegion(interfaces, builder.getInsertionBlock());
+  auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
   auto &entry = insertRegion->front();
-  OpBuilder::InsertionGuard foldGuard(builder);
-  builder.setInsertionPoint(&entry, entry.begin());
+  rewriter.setInsertionPoint(&entry, entry.begin());
 
   // Get the constant map for the insertion region of this operation.
   auto &uniquedConstants = foldScopes[insertRegion];
@@ -300,9 +284,8 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
     // Check to see if there is a canonicalized version of this constant.
     auto res = op->getResult(i);
     Attribute attrRepl = foldResults[i].get<Attribute>();
-    if (auto *constOp =
-            tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
-                                   res.getType(), op->getLoc())) {
+    if (auto *constOp = tryGetOrCreateConstant(
+            uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
       // Ensure that this constant dominates the operation we are replacing it
       // with. This may not automatically happen if the operation being folded
       // was inserted before the constant within the insertion block.
@@ -316,8 +299,10 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
     // If materialization fails, cleanup any operations generated for the
     // previous results and return failure.
     for (Operation &op : llvm::make_early_inc_range(
-             llvm::make_range(entry.begin(), builder.getInsertionPoint())))
-      eraseOp(&op);
+             llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
+      notifyRemoval(&op);
+      rewriter.eraseOp(&op);
+    }
 
     results.clear();
     return failure();
@@ -328,9 +313,10 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
 
 /// Try to get or create a new constant entry. On success this returns the
 /// constant operation value, nullptr otherwise.
-Operation *OperationFolder::tryGetOrCreateConstant(
-    ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
-    Attribute value, Type type, Location loc) {
+Operation *
+OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
+                                        Dialect *dialect, Attribute value,
+                                        Type type, Location loc) {
   // Check if an existing mapping already exists.
   auto constKey = std::make_tuple(dialect, value, type);
   Operation *&constOp = uniquedConstants[constKey];
@@ -338,7 +324,7 @@ Operation *OperationFolder::tryGetOrCreateConstant(
     return constOp;
 
   // If one doesn't exist, try to materialize one.
-  if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
+  if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
     return nullptr;
 
   // Check to see if the generated constant is in the expected dialect.
@@ -355,7 +341,8 @@ Operation *OperationFolder::tryGetOrCreateConstant(
   // If an existing operation in the new dialect already exists, delete the
   // materialized operation in favor of the existing one.
   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
-    eraseOp(constOp);
+    notifyRemoval(constOp);
+    rewriter.eraseOp(constOp);
     referencedDialects[existingOp].push_back(dialect);
     return constOp = existingOp;
   }

diff  --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
index 64ff4ce5b9e5..d1978b6099f0 100644
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -39,8 +39,9 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
       maybeDefiningOp ? maybeDefiningOp->getDialect()
                       : value.getParentRegion()->getParentOp()->getDialect();
   Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
-  Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
-                                              value.getType(), value.getLoc());
+  Value constant =
+      folder.getOrCreateConstant(b.getInsertionBlock(), valueDialect, constAttr,
+                                 value.getType(), value.getLoc());
   if (!constant)
     return failure();
 


        


More information about the Mlir-commits mailing list