[Mlir-commits] [mlir] 9297b9f - [mlir][Transforms][NFC] Improve builder/listener API of OperationFolder
Matthias Springer
llvmlistbot at llvm.org
Wed Mar 22 01:24:57 PDT 2023
Author: Matthias Springer
Date: 2023-03-22T09:24:47+01:00
New Revision: 9297b9f8eeecc5ea6571cf45985ba77bc2960427
URL: https://github.com/llvm/llvm-project/commit/9297b9f8eeecc5ea6571cf45985ba77bc2960427
DIFF: https://github.com/llvm/llvm-project/commit/9297b9f8eeecc5ea6571cf45985ba77bc2960427.diff
LOG: [mlir][Transforms][NFC] Improve builder/listener API of OperationFolder
The constructor of `OperationFolder` already takes a listener, so the rest of its API should not also take builders or rewriters: if a listener were attached to such a builder/rewriter as well, the same IR change could be notified twice.
As an internal cleanup, `OperationFolder` now owns an `IRRewriter` instead of a `RewriterBase::Listener`. In most cases, `OperationFolder` no longer has to notify or otherwise deal with listeners itself; the rewriter takes care of this.
Differential Revision: https://reviews.llvm.org/D146134
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/FoldUtils.h
mlir/lib/Transforms/SCCP.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/test/lib/Transforms/TestIntRangeInference.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9c4790c03120..600ace488273 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -624,7 +624,9 @@ class RewriterBase : public OpBuilder {
protected:
/// Initialize the builder.
- explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx) {}
+ explicit RewriterBase(MLIRContext *ctx,
+ OpBuilder::Listener *listener = nullptr)
+ : OpBuilder(ctx, listener) {}
explicit RewriterBase(const OpBuilder &otherBuilder)
: OpBuilder(otherBuilder) {}
virtual ~RewriterBase();
@@ -648,7 +650,8 @@ class RewriterBase : public OpBuilder {
/// such as a `PatternRewriter`, is not available.
class IRRewriter : public RewriterBase {
public:
- explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+ explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
+ : RewriterBase(ctx, listener) {}
explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
};
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index a6dc18369e77..2600da361496 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -32,8 +32,8 @@ class Value;
/// generated along the way.
class OperationFolder {
public:
- OperationFolder(MLIRContext *ctx, RewriterBase::Listener *listener = nullptr)
- : interfaces(ctx), listener(listener) {}
+ OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
+ : interfaces(ctx), rewriter(ctx, listener) {}
/// Tries to perform folding on the given `op`, including unifying
/// deduplicated constants. If successful, replaces `op`'s uses with
@@ -61,10 +61,11 @@ class OperationFolder {
/// Clear out any constants cached inside of the folder.
void clear();
- /// Get or create a constant using the given builder. On success this returns
- /// the constant operation, nullptr otherwise.
- Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
- Attribute value, Type type, Location loc);
+ /// Get or create a constant for use in the specified block. The constant may
+ /// be created in a parent block. On success this returns the constant
+ /// operation, nullptr otherwise.
+ Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value,
+ Type type, Location loc);
private:
/// This map keeps track of uniqued constants by dialect, attribute, and type.
@@ -74,29 +75,25 @@ class OperationFolder {
using ConstantMap =
DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
- /// Erase the given operation and notify the listener.
- void eraseOp(Operation *op);
-
/// Returns true if the given operation is an already folded constant that is
/// owned by this folder.
bool isFolderOwnedConstant(Operation *op) const;
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
- LogicalResult tryToFold(OpBuilder &builder, Operation *op,
- SmallVectorImpl<Value> &results);
+ LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results);
- /// Try to process a set of fold results, generating constants as necessary.
- /// Populates `results` on success, otherwise leaves it unchanged.
- LogicalResult processFoldResults(OpBuilder &builder, Operation *op,
+ /// Try to process a set of fold results. Populates `results` on success,
+ /// otherwise leaves it unchanged.
+ LogicalResult processFoldResults(Operation *op,
SmallVectorImpl<Value> &results,
ArrayRef<OpFoldResult> foldResults);
/// Try to get or create a new constant entry. On success this returns the
/// constant operation, nullptr otherwise.
Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
- Dialect *dialect, OpBuilder &builder,
- Attribute value, Type type, Location loc);
+ Dialect *dialect, Attribute value,
+ Type type, Location loc);
/// A mapping between an insertion region and the constants that have been
/// created within it.
@@ -109,8 +106,8 @@ class OperationFolder {
/// A collection of dialect folder interfaces.
DialectInterfaceCollection<DialectFoldInterface> interfaces;
- /// An optional listener that is notified of all IR changes.
- RewriterBase::Listener *listener = nullptr;
+ /// A rewriter that performs all IR modifications.
+ IRRewriter rewriter;
};
} // namespace mlir
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index b32173a3a981..14435b37acc9 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -51,9 +51,9 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver,
// Attempt to materialize a constant for the given value.
Dialect *dialect = latticeValue.getConstantDialect();
- Value constant = folder.getOrCreateConstant(builder, dialect,
- latticeValue.getConstantValue(),
- value.getType(), value.getLoc());
+ Value constant = folder.getOrCreateConstant(
+ builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(),
+ value.getType(), value.getLoc());
if (!constant)
return failure();
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 22a488efb37e..827c0ad4290b 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -84,26 +84,25 @@ LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
// Try to fold the operation.
SmallVector<Value, 8> results;
- OpBuilder builder(op, listener);
- if (failed(tryToFold(builder, op, results)))
+ if (failed(tryToFold(op, results)))
return failure();
// Check to see if the operation was just updated in place.
if (results.empty()) {
if (inPlaceUpdate)
*inPlaceUpdate = true;
- if (listener)
- listener->notifyOperationModified(op);
+ if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
+ rewriter.getListener())) {
+ // Folding API does not notify listeners, so we have to notify manually.
+ rewriteListener->notifyOperationModified(op);
+ }
return success();
}
// Constant folding succeeded. Replace all of the result values and erase the
// operation.
- if (listener)
- listener->notifyOperationReplaced(op, results);
- for (unsigned i = 0, e = results.size(); i != e; ++i)
- op->getResult(i).replaceAllUsesWith(results[i]);
- eraseOp(op);
+ notifyRemoval(op);
+ rewriter.replaceOp(op, results);
return success();
}
@@ -141,10 +140,8 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
// If there is an existing constant, replace `op`.
if (folderConstOp) {
- if (listener)
- listener->notifyOperationReplaced(op, folderConstOp->getResults());
- op->replaceAllUsesWith(folderConstOp);
- eraseOp(op);
+ notifyRemoval(op);
+ rewriter.replaceOp(op, folderConstOp->getResults());
return false;
}
@@ -162,13 +159,6 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
return true;
}
-void OperationFolder::eraseOp(Operation *op) {
- notifyRemoval(op);
- if (listener)
- listener->notifyOperationRemoved(op);
- op->erase();
-}
-
/// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping.
void OperationFolder::notifyRemoval(Operation *op) {
@@ -202,22 +192,18 @@ void OperationFolder::clear() {
/// Get or create a constant using the given builder. On success this returns
/// the constant operation, nullptr otherwise.
-Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
+Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
Attribute value, Type type,
Location loc) {
- OpBuilder::InsertionGuard foldGuard(builder);
-
- // Use the builder insertion block to find an insertion point for the
- // constant.
- auto *insertRegion =
- getInsertionRegion(interfaces, builder.getInsertionBlock());
+ // Find an insertion point for the constant.
+ auto *insertRegion = getInsertionRegion(interfaces, block);
auto &entry = insertRegion->front();
- builder.setInsertionPoint(&entry, entry.begin());
+ rewriter.setInsertionPoint(&entry, entry.begin());
// Get the constant map for the insertion region of this operation.
auto &uniquedConstants = foldScopes[insertRegion];
- Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
- builder, value, type, loc);
+ Operation *constOp =
+ tryGetOrCreateConstant(uniquedConstants, dialect, value, type, loc);
return constOp ? constOp->getResult(0) : Value();
}
@@ -227,7 +213,7 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
-LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op,
+LogicalResult OperationFolder::tryToFold(Operation *op,
SmallVectorImpl<Value> &results) {
SmallVector<Attribute, 8> operandConstants;
@@ -257,13 +243,13 @@ LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op,
// fold.
SmallVector<OpFoldResult, 8> foldResults;
if (failed(op->fold(operandConstants, foldResults)) ||
- failed(processFoldResults(builder, op, results, foldResults)))
+ failed(processFoldResults(op, results, foldResults)))
return success(updatedOpOperands);
return success();
}
LogicalResult
-OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
+OperationFolder::processFoldResults(Operation *op,
SmallVectorImpl<Value> &results,
ArrayRef<OpFoldResult> foldResults) {
// Check to see if the operation was just updated in place.
@@ -273,11 +259,9 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
// Create a builder to insert new operations into the entry block of the
// insertion region.
- auto *insertRegion =
- getInsertionRegion(interfaces, builder.getInsertionBlock());
+ auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
auto &entry = insertRegion->front();
- OpBuilder::InsertionGuard foldGuard(builder);
- builder.setInsertionPoint(&entry, entry.begin());
+ rewriter.setInsertionPoint(&entry, entry.begin());
// Get the constant map for the insertion region of this operation.
auto &uniquedConstants = foldScopes[insertRegion];
@@ -300,9 +284,8 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
// Check to see if there is a canonicalized version of this constant.
auto res = op->getResult(i);
Attribute attrRepl = foldResults[i].get<Attribute>();
- if (auto *constOp =
- tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
- res.getType(), op->getLoc())) {
+ if (auto *constOp = tryGetOrCreateConstant(
+ uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
// Ensure that this constant dominates the operation we are replacing it
// with. This may not automatically happen if the operation being folded
// was inserted before the constant within the insertion block.
@@ -316,8 +299,10 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
// If materialization fails, cleanup any operations generated for the
// previous results and return failure.
for (Operation &op : llvm::make_early_inc_range(
- llvm::make_range(entry.begin(), builder.getInsertionPoint())))
- eraseOp(&op);
+ llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
+ notifyRemoval(&op);
+ rewriter.eraseOp(&op);
+ }
results.clear();
return failure();
@@ -328,9 +313,10 @@ OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
/// Try to get or create a new constant entry. On success this returns the
/// constant operation value, nullptr otherwise.
-Operation *OperationFolder::tryGetOrCreateConstant(
- ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
- Attribute value, Type type, Location loc) {
+Operation *
+OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
+ Dialect *dialect, Attribute value,
+ Type type, Location loc) {
// Check if an existing mapping already exists.
auto constKey = std::make_tuple(dialect, value, type);
Operation *&constOp = uniquedConstants[constKey];
@@ -338,7 +324,7 @@ Operation *OperationFolder::tryGetOrCreateConstant(
return constOp;
// If one doesn't exist, try to materialize one.
- if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
+ if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
return nullptr;
// Check to see if the generated constant is in the expected dialect.
@@ -355,7 +341,8 @@ Operation *OperationFolder::tryGetOrCreateConstant(
// If an existing operation in the new dialect already exists, delete the
// materialized operation in favor of the existing one.
if (auto *existingOp = uniquedConstants.lookup(newKey)) {
- eraseOp(constOp);
+ notifyRemoval(constOp);
+ rewriter.eraseOp(constOp);
referencedDialects[existingOp].push_back(dialect);
return constOp = existingOp;
}
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
index 64ff4ce5b9e5..d1978b6099f0 100644
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -39,8 +39,9 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
- Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
- value.getType(), value.getLoc());
+ Value constant =
+ folder.getOrCreateConstant(b.getInsertionBlock(), valueDialect, constAttr,
+ value.getType(), value.getLoc());
if (!constant)
return failure();
More information about the Mlir-commits
mailing list