[Mlir-commits] [mlir] [mlir][Transforms] Track erased ops separately (PR #83051)
Matthias Springer
llvmlistbot at llvm.org
Mon Feb 26 11:46:11 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83051
#83023 fixed a performance regression related to "ignored" ops. However, that fix broke some downstream projects that access ops after they were replaced (an API violation). This change restores the behavior from before #83023 (but without the performance regression), to give downstream users more time to fix their code.
>From 238322164f0c973318ddfcb9d66bcc5b05e7546c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 26 Feb 2024 19:41:55 +0000
Subject: [PATCH] [mlir][Transforms] Track erased ops separately
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
.../Transforms/Utils/DialectConversion.cpp | 23 ++++++++++++++-----
1 file changed, 17 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4165e0a52428f9..f967e8352bf4c8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
- unsigned numErased)
+ unsigned numErased, unsigned numReplacedOps)
: numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
- numErased(numErased) {}
+ numErased(numErased), numReplacedOps(numReplacedOps) {}
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -165,6 +165,9 @@ struct RewriterState {
/// The current number of erased operations/blocks.
unsigned numErased;
+
+ /// The current number of replaced ops that are scheduled for erasure.
+ unsigned numReplacedOps;
};
//===----------------------------------------------------------------------===//
@@ -954,6 +957,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
+ /// A set of operations that were replaced and are scheduled for erasure.
+ SetVector<Operation *> replacedOps;
+
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1152,7 +1158,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(rewrites.size(), ignoredOps.size(),
- eraseRewriter.erased.size());
+ eraseRewriter.erased.size(), replacedOps.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1165,6 +1171,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (eraseRewriter.erased.size() != state.numErased)
eraseRewriter.erased.pop_back();
+
+ while (replacedOps.size() != state.numReplacedOps)
+ replacedOps.pop_back();
}
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1228,9 +1237,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
return success();
}
+// TODO: This function is a misnomer. It does not check whether `op` itself is
+// in `ignoredOps`; it checks `op`'s parent against `ignoredOps` and `op`
+// against `replacedOps`.
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation or the parent operation is ignored.
- return ignoredOps.count(op->getParentOp()) || ignoredOps.count(op);
+ return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
}
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1479,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
- assert(!ignoredOps.contains(op) && "operation was already replaced");
+ assert(!replacedOps.contains(op) && "operation was already replaced");
// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
@@ -1500,7 +1511,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
- ignoredOps.insert(op);
+ replacedOps.insert(op);
markNestedOpsIgnored(op);
}
More information about the Mlir-commits
mailing list