[Mlir-commits] [mlir] 9bdfa8d - [mlir][IR] Use Listener for IR callbacks in OperationFolder
Matthias Springer
llvmlistbot at llvm.org
Thu Feb 23 00:02:53 PST 2023
Author: Matthias Springer
Date: 2023-02-23T08:56:43+01:00
New Revision: 9bdfa8df0db21845b8e1d8fc0fc8b70dfe25f45d
URL: https://github.com/llvm/llvm-project/commit/9bdfa8df0db21845b8e1d8fc0fc8b70dfe25f45d
DIFF: https://github.com/llvm/llvm-project/commit/9bdfa8df0db21845b8e1d8fc0fc8b70dfe25f45d.diff
LOG: [mlir][IR] Use Listener for IR callbacks in OperationFolder
Remove the IR modification callbacks from `OperationFolder`. Instead, an optional `RewriterBase::Listener` can be specified.
* `processGeneratedConstants` => `notifyOperationCreated`
* `preReplaceAction` => `notifyOperationReplaced`
This simplifies the GreedyPatternRewriteDriver because it no longer needs special handling for IR modifications caused by op folding.
A folded operation is now enqueued on the GreedyPatternRewriteDriver's worklist if it was modified in place, because new patterns may apply to it after folding.
Also fixes a bug in `TestOpInPlaceFold::fold`: previously, the folder could be applied over and over because it never returned a "null" OpFoldResult when it left the IR unmodified. (This is similar to a pattern that returns `success` without modifying the IR; it can trigger an infinite loop in the GreedyPatternRewriteDriver.)
Differential Revision: https://reviews.llvm.org/D144463
Added:
Modified:
mlir/include/mlir/Transforms/FoldUtils.h
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Transforms/TestConstantFold.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 2e5a80cbce35d..ff1083724e59c 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -17,6 +17,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FoldInterfaces.h"
namespace mlir {
@@ -31,19 +32,14 @@ class Value;
/// generated along the way.
class OperationFolder {
public:
- OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
+ OperationFolder(MLIRContext *ctx, RewriterBase::Listener *listener = nullptr)
+ : interfaces(ctx), listener(listener) {}
/// Tries to perform folding on the given `op`, including unifying
/// deduplicated constants. If successful, replaces `op`'s uses with
- /// folded results, and returns success. `preReplaceAction` is invoked on `op`
- /// before it is replaced. 'processGeneratedConstants' is invoked for any new
- /// operations generated when folding. If the op was completely folded it is
+ /// folded results, and returns success. If the op was completely folded it is
/// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
- LogicalResult
- tryToFold(Operation *op,
- function_ref<void(Operation *)> processGeneratedConstants = nullptr,
- function_ref<void(Operation *)> preReplaceAction = nullptr,
- bool *inPlaceUpdate = nullptr);
+ LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr);
/// Tries to fold a pre-existing constant operation. `constValue` represents
/// the value of the constant, and can be optionally passed if the value is
@@ -122,23 +118,23 @@ class OperationFolder {
using ConstantMap =
DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
+ /// Erase the given operation and notify the listener.
+ void eraseOp(Operation *op);
+
/// Returns true if the given operation is an already folded constant that is
/// owned by this folder.
bool isFolderOwnedConstant(Operation *op) const;
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
- LogicalResult tryToFold(
- OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
- function_ref<void(Operation *)> processGeneratedConstants = nullptr);
+ LogicalResult tryToFold(OpBuilder &builder, Operation *op,
+ SmallVectorImpl<Value> &results);
/// Try to process a set of fold results, generating constants as necessary.
/// Populates `results` on success, otherwise leaves it unchanged.
- LogicalResult
- processFoldResults(OpBuilder &builder, Operation *op,
- SmallVectorImpl<Value> &results,
- ArrayRef<OpFoldResult> foldResults,
- function_ref<void(Operation *)> processGeneratedConstants);
+ LogicalResult processFoldResults(OpBuilder &builder, Operation *op,
+ SmallVectorImpl<Value> &results,
+ ArrayRef<OpFoldResult> foldResults);
/// Try to get or create a new constant entry. On success this returns the
/// constant operation, nullptr otherwise.
@@ -156,6 +152,9 @@ class OperationFolder {
/// A collection of dialect folder interfaces.
DialectInterfaceCollection<DialectFoldInterface> interfaces;
+
+ /// An optional listener that is notified of all IR changes.
+ RewriterBase::Listener *listener = nullptr;
};
} // namespace mlir
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index fcfdbe4afab5f..22a488efb37ea 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -67,9 +67,7 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
// OperationFolder
//===----------------------------------------------------------------------===//
-LogicalResult OperationFolder::tryToFold(
- Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
- function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
+LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
if (inPlaceUpdate)
*inPlaceUpdate = false;
@@ -86,27 +84,26 @@ LogicalResult OperationFolder::tryToFold(
// Try to fold the operation.
SmallVector<Value, 8> results;
- OpBuilder builder(op);
- if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
+ OpBuilder builder(op, listener);
+ if (failed(tryToFold(builder, op, results)))
return failure();
// Check to see if the operation was just updated in place.
if (results.empty()) {
if (inPlaceUpdate)
*inPlaceUpdate = true;
+ if (listener)
+ listener->notifyOperationModified(op);
return success();
}
- // Constant folding succeeded. We will start replacing this op's uses and
- // erase this op. Invoke the callback provided by the caller to perform any
- // pre-replacement action.
- if (preReplaceAction)
- preReplaceAction(op);
-
- // Replace all of the result values and erase the operation.
+ // Constant folding succeeded. Replace all of the result values and erase the
+ // operation.
+ if (listener)
+ listener->notifyOperationReplaced(op, results);
for (unsigned i = 0, e = results.size(); i != e; ++i)
op->getResult(i).replaceAllUsesWith(results[i]);
- op->erase();
+ eraseOp(op);
return success();
}
@@ -144,8 +141,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
// If there is an existing constant, replace `op`.
if (folderConstOp) {
+ if (listener)
+ listener->notifyOperationReplaced(op, folderConstOp->getResults());
op->replaceAllUsesWith(folderConstOp);
- op->erase();
+ eraseOp(op);
return false;
}
@@ -163,6 +162,13 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
return true;
}
+void OperationFolder::eraseOp(Operation *op) {
+ notifyRemoval(op);
+ if (listener)
+ listener->notifyOperationRemoved(op);
+ op->erase();
+}
+
/// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping.
void OperationFolder::notifyRemoval(Operation *op) {
@@ -221,9 +227,8 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
-LogicalResult OperationFolder::tryToFold(
- OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
- function_ref<void(Operation *)> processGeneratedConstants) {
+LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op,
+ SmallVectorImpl<Value> &results) {
SmallVector<Attribute, 8> operandConstants;
// If this is a commutative operation, move constants to be trailing operands.
@@ -252,16 +257,15 @@ LogicalResult OperationFolder::tryToFold(
// fold.
SmallVector<OpFoldResult, 8> foldResults;
if (failed(op->fold(operandConstants, foldResults)) ||
- failed(processFoldResults(builder, op, results, foldResults,
- processGeneratedConstants)))
+ failed(processFoldResults(builder, op, results, foldResults)))
return success(updatedOpOperands);
return success();
}
-LogicalResult OperationFolder::processFoldResults(
- OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
- ArrayRef<OpFoldResult> foldResults,
- function_ref<void(Operation *)> processGeneratedConstants) {
+LogicalResult
+OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
+ SmallVectorImpl<Value> &results,
+ ArrayRef<OpFoldResult> foldResults) {
// Check to see if the operation was just updated in place.
if (foldResults.empty())
return success();
@@ -312,20 +316,13 @@ LogicalResult OperationFolder::processFoldResults(
// If materialization fails, cleanup any operations generated for the
// previous results and return failure.
for (Operation &op : llvm::make_early_inc_range(
- llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
- notifyRemoval(&op);
- op.erase();
- }
+ llvm::make_range(entry.begin(), builder.getInsertionPoint())))
+ eraseOp(&op);
+
results.clear();
return failure();
}
- // Process any newly generated operations.
- if (processGeneratedConstants) {
- for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
- processGeneratedConstants(&*i);
- }
-
return success();
}
@@ -358,7 +355,7 @@ Operation *OperationFolder::tryGetOrCreateConstant(
// If an existing operation in the new dialect already exists, delete the
// materialized operation in favor of the existing one.
if (auto *existingOp = uniquedConstants.lookup(newKey)) {
- constOp->erase();
+ eraseOp(constOp);
referencedDialects[existingOp].push_back(dialect);
return constOp = existingOp;
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 15495efca09af..a5977ccec1182 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -127,7 +127,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) {
+ : PatternRewriter(ctx), folder(ctx, this), config(config),
+ matcher(patterns) {
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
@@ -156,9 +157,6 @@ bool GreedyPatternRewriteDriver::processWorklist() {
};
#endif
- // These are scratch vectors used in the folding loop below.
- SmallVector<Value, 8> originalOperands;
-
bool changed = false;
int64_t numRewrites = 0;
while (!worklist.empty() &&
@@ -197,34 +195,11 @@ bool GreedyPatternRewriteDriver::processWorklist() {
continue;
}
- // Collects all the operands and result uses of the given `op` into work
- // list. Also remove `op` and nested ops from worklist.
- originalOperands.assign(op->operand_begin(), op->operand_end());
- auto preReplaceAction = [&](Operation *op) {
- // Add the operands to the worklist for visitation.
- addOperandsToWorklist(originalOperands);
-
- // Add all the users of the result to the worklist so we make sure
- // to revisit them.
- for (auto result : op->getResults())
- for (auto *userOp : result.getUsers())
- addToWorklist(userOp);
-
- notifyOperationRemoved(op);
- };
-
- // Add the given operation to the worklist.
- auto collectOps = [this](Operation *op) { addToWorklist(op); };
-
// Try to fold this op.
- bool inPlaceUpdate;
- if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
- &inPlaceUpdate)))) {
+ if (succeeded(folder.tryToFold(op))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-
changed = true;
- if (!inPlaceUpdate)
- continue;
+ continue;
}
// Try to match one of the patterns. The rewriter is automatically
@@ -465,7 +440,7 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
// Add all nested operations to the worklist in preorder.
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (!insertKnownConstant(op)) {
- worklist.push_back(op);
+ addToWorklist(op);
return WalkResult::advance();
}
return WalkResult::skip();
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index dc5f629610d90..d27349710c451 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1116,7 +1116,8 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
}
OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
- if (adaptor.getOp()) {
+ if (adaptor.getOp() && !(*this)->hasAttr("attr")) {
+ // The folder adds "attr" if not present.
(*this)->setAttr("attr", adaptor.getOp());
return getResult();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e9816ebcc13e8..b4447eb8cc142 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1280,7 +1280,7 @@ def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
}
def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
- let arguments = (ins I32:$op, I32Attr:$attr);
+ let arguments = (ins I32:$op, OptionalAttr<I32Attr>:$attr);
let results = (outs I32);
let hasFolder = 1;
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 4bfbb3496ec3a..66c369d845bdf 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -93,8 +93,7 @@ struct FoldingPattern : public RewritePattern {
// (unchanged) operation result.
OperationFolder folder(op->getContext());
Value result = folder.create<TestOpInPlaceFold>(
- rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
- rewriter.getI32IntegerAttr(0));
+ rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
assert(result);
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp
index 1af923e014d52..9896cf372d269 100644
--- a/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -13,8 +13,8 @@ using namespace mlir;
namespace {
/// Simple constant folding pass.
-struct TestConstantFold
- : public PassWrapper<TestConstantFold, OperationPass<>> {
+struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
+ public RewriterBase::Listener {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold)
StringRef getArgument() const final { return "test-constant-fold"; }
@@ -26,17 +26,22 @@ struct TestConstantFold
void foldOperation(Operation *op, OperationFolder &helper);
void runOnOperation() override;
+
+ void notifyOperationInserted(Operation *op) override {
+ existingConstants.push_back(op);
+ }
+ void notifyOperationRemoved(Operation *op) override {
+ auto it = llvm::find(existingConstants, op);
+ if (it != existingConstants.end())
+ existingConstants.erase(it);
+ }
};
} // namespace
void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
- auto processGeneratedConstants = [this](Operation *op) {
- existingConstants.push_back(op);
- };
-
// Attempt to fold the specified operation, including handling unused or
// duplicated constants.
- (void)helper.tryToFold(op, processGeneratedConstants);
+ (void)helper.tryToFold(op);
}
void TestConstantFold::runOnOperation() {
@@ -50,7 +55,7 @@ void TestConstantFold::runOnOperation() {
// folding are at the beginning. This creates somewhat of a linear ordering to
// the newly generated constants that matches the operation order and improves
// the readability of test cases.
- OperationFolder helper(&getContext());
+ OperationFolder helper(&getContext(), /*listener=*/this);
for (Operation *op : llvm::reverse(ops))
foldOperation(op, helper);
More information about the Mlir-commits
mailing list