[Mlir-commits] [mlir] d588e49 - [mlir][Transforms][NFC] Dialect conversion: Cache `UnresolvedMaterializationRewrite` (#108359)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 13 11:16:09 PDT 2024
Author: Matthias Springer
Date: 2024-09-13T20:16:05+02:00
New Revision: d588e49a324b3d6039c19f3108d722a8b9fcd96e
URL: https://github.com/llvm/llvm-project/commit/d588e49a324b3d6039c19f3108d722a8b9fcd96e
DIFF: https://github.com/llvm/llvm-project/commit/d588e49a324b3d6039c19f3108d722a8b9fcd96e.diff
LOG: [mlir][Transforms][NFC] Dialect conversion: Cache `UnresolvedMaterializationRewrite` (#108359)
The dialect conversion maintains a set of unresolved materializations
(`UnrealizedConversionCastOp`). Turn that set into a `DenseMap` that
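maps from ops to `UnresolvedMaterializationRewrite *`. This improves
efficiency a bit, because the rewrite object of a materialization can now
be looked up directly instead of iterating over all of
`ConversionPatternRewriterImpl::rewrites`.
Also delete some dead code (`findSingleRewrite` and the
`wasErased(OperationRewrite *)` overload).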
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..caea9e111afeda 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {}
+ MaterializationKind kind = MaterializationKind::Target);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
- RewriteTy *result = nullptr;
- for (auto &rewrite : rewrites) {
- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
- if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
- assert(!result && "expected single matching rewrite");
- result = rewriteTy;
-#else
- return rewriteTy;
-#endif // NDEBUG
- }
- }
- return result;
-}
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- bool wasErased(OperationRewrite *rewrite) const {
- return wasErased(rewrite->getOperation());
- }
-
void notifyOperationErased(Operation *op) override { erased.insert(op); }
void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
- /// A set of all unresolved materializations.
- DenseSet<Operation *> unresolvedMaterializations;
+ /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+ /// to the corresponding rewrite objects.
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
/// active.
@@ -1058,12 +1034,20 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+ const TypeConverter *converter, MaterializationKind kind)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind) {
+ rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
void UnresolvedMaterializationRewrite::rollback() {
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
}
- rewriterImpl.unresolvedMaterializations.erase(op);
+ rewriterImpl.unresolvedMaterializations.erase(getOperation());
op->erase();
}
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
@@ -1382,10 +1365,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
if (!newValue) {
// This result was dropped and no replacement value was provided.
- if (unresolvedMaterializations.contains(op)) {
- // Do not create another materializations if we are erasing a
- // materialization.
- continue;
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ if (unresolvedMaterializations.contains(castOp)) {
+ // Do not create another materialization if we are erasing a
+ // materialization.
+ continue;
+ }
}
// Materialize a replacement value "out of thin air".
@@ -2499,15 +2484,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- if (rewriterImpl.eraseRewriter.wasErased(mat))
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ &materializations = rewriterImpl.unresolvedMaterializations;
+ for (auto it : materializations) {
+ if (rewriterImpl.eraseRewriter.wasErased(it.first))
continue;
- allCastOps.push_back(mat->getOperation());
- rewriteMap[mat->getOperation()] = mat;
+ allCastOps.push_back(it.first);
}
// Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2502,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
- auto it = rewriteMap.find(castOp.getOperation());
- assert(it != rewriteMap.end() && "inconsistent state");
+ auto it = materializations.find(castOp);
+ assert(it != materializations.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
return failure();
}
More information about the Mlir-commits
mailing list