[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect conversion: Cache `UnresolvedMaterializationRewrite` (PR #108359)
Matthias Springer
llvmlistbot at llvm.org
Thu Sep 12 06:31:37 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/108359
>From 4cb4bcf54a267258f9d7129b347419897a164df7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 12 Sep 2024 12:45:44 +0200
Subject: [PATCH] [mlir][Transforms][NFC] Dialect conversion: Cache
`UnresolvedMaterializationRewrite`
The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp` ops). Turn that set into a map from each such op to its `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because the iteration over `ConversionPatternRewriterImpl::rewrites` that was previously needed to collect these rewrite objects can be avoided.
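Also delete code that is dead or no longer needed after this change: the `findSingleRewrite` helper and the `wasErased(OperationRewrite *)` overload.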
---
.../Transforms/Utils/DialectConversion.cpp | 60 +++++++------------
1 file changed, 20 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..ed15b571f01883 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {}
+ MaterializationKind kind = MaterializationKind::Target);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
- RewriteTy *result = nullptr;
- for (auto &rewrite : rewrites) {
- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
- if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
- assert(!result && "expected single matching rewrite");
- result = rewriteTy;
-#else
- return rewriteTy;
-#endif // NDEBUG
- }
- }
- return result;
-}
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- bool wasErased(OperationRewrite *rewrite) const {
- return wasErased(rewrite->getOperation());
- }
-
void notifyOperationErased(Operation *op) override { erased.insert(op); }
void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
- /// A set of all unresolved materializations.
- DenseSet<Operation *> unresolvedMaterializations;
+ /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+ /// to the corresponding rewrite objects.
+ DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
/// active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+ const TypeConverter *converter, MaterializationKind kind)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind) {
+ rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
void UnresolvedMaterializationRewrite::rollback() {
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- if (rewriterImpl.eraseRewriter.wasErased(mat))
+ const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ &materializations = rewriterImpl.unresolvedMaterializations;
+ for (auto it : materializations) {
+ if (rewriterImpl.eraseRewriter.wasErased(it.first))
continue;
- allCastOps.push_back(mat->getOperation());
- rewriteMap[mat->getOperation()] = mat;
+ allCastOps.push_back(cast<UnrealizedConversionCastOp>(it.first));
}
// Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
- auto it = rewriteMap.find(castOp.getOperation());
- assert(it != rewriteMap.end() && "inconsistent state");
+ auto it = materializations.find(castOp.getOperation());
+ assert(it != materializations.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
return failure();
}
More information about the Mlir-commits
mailing list