[Mlir-commits] [mlir] 7b66b5d - [mlir][Transforms] Track erased ops separately (#83051)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 26 11:49:57 PST 2024
Author: Matthias Springer
Date: 2024-02-26T20:49:52+01:00
New Revision: 7b66b5d6c2fadb1e9f8cfc8d1864d4109105001f
URL: https://github.com/llvm/llvm-project/commit/7b66b5d6c2fadb1e9f8cfc8d1864d4109105001f
DIFF: https://github.com/llvm/llvm-project/commit/7b66b5d6c2fadb1e9f8cfc8d1864d4109105001f.diff
LOG: [mlir][Transforms] Track erased ops separately (#83051)
#83023 fixed a performance regression related to "ignored" ops. This
broke some downstream projects that access ops after they were replaced
(an API violation). This change restores the original behavior before
#83023 (but without the performance regression), to give downstream
users more time to fix their code.
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4165e0a52428f9..f967e8352bf4c8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
- unsigned numErased)
+ unsigned numErased, unsigned numReplacedOps)
: numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
- numErased(numErased) {}
+ numErased(numErased), numReplacedOps(numReplacedOps) {}
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -165,6 +165,9 @@ struct RewriterState {
/// The current number of erased operations/blocks.
unsigned numErased;
+
+ /// The current number of replaced ops that are scheduled for erasure.
+ unsigned numReplacedOps;
};
//===----------------------------------------------------------------------===//
@@ -954,6 +957,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
+ /// A set of operations that were replaced and are scheduled for erasure.
+ SetVector<Operation *> replacedOps;
+
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1152,7 +1158,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(rewrites.size(), ignoredOps.size(),
- eraseRewriter.erased.size());
+ eraseRewriter.erased.size(), replacedOps.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1165,6 +1171,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (eraseRewriter.erased.size() != state.numErased)
eraseRewriter.erased.pop_back();
+
+ while (replacedOps.size() != state.numReplacedOps)
+ replacedOps.pop_back();
}
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1228,9 +1237,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
return success();
}
+// TODO: This function is a misnomer. It does not actually check if `op` is in
+// `ignoredOps`.
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation or the parent operation is ignored.
- return ignoredOps.count(op->getParentOp()) || ignoredOps.count(op);
+ return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
}
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1479,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
- assert(!ignoredOps.contains(op) && "operation was already replaced");
+ assert(!replacedOps.contains(op) && "operation was already replaced");
// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
@@ -1500,7 +1511,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
- ignoredOps.insert(op);
+ replacedOps.insert(op);
markNestedOpsIgnored(op);
}
More information about the Mlir-commits
mailing list