[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Dialect conversion: Cache `UnresolvedMaterializationRewrite` (PR #108359)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Sep 12 03:49:32 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/108359

The dialect conversion maintains a set of unresolved materializations (`UnrealizedConversionCastOp` ops). Turn that set into a `DenseMap` that maps each such op to its `UnresolvedMaterializationRewrite *`. This is slightly more efficient: when gathering the unresolved materializations at the end of the conversion, the rewrite objects can be looked up in the map directly, so the iteration over `ConversionPatternRewriterImpl::rewrites` is no longer needed.

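Also delete some dead code: the `findSingleRewrite` helper and the `wasErased(OperationRewrite *)` overload, which is no longer needed after this change.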

>From a9c69d1733662b3299bd3f4d41982422640dc034 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 12 Sep 2024 12:45:44 +0200
Subject: [PATCH] [mlir][Transforms][NFC] Dialect conversion: Cache
 `UnresolvedMaterializationRewrite`

The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a map that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided.

Also delete some dead code.
---
 .../Transforms/Utils/DialectConversion.cpp    | 60 +++++++------------
 1 file changed, 20 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..ed15b571f01883 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target)
-      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind) {}
+      MaterializationKind kind = MaterializationKind::Target);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
-  RewriteTy *result = nullptr;
-  for (auto &rewrite : rewrites) {
-    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
-    if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
-      assert(!result && "expected single matching rewrite");
-      result = rewriteTy;
-#else
-      return rewriteTy;
-#endif // NDEBUG
-    }
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    bool wasErased(OperationRewrite *rewrite) const {
-      return wasErased(rewrite->getOperation());
-    }
-
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
-  /// A set of all unresolved materializations.
-  DenseSet<Operation *> unresolvedMaterializations;
+  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+  /// to the corresponding rewrite objects.
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      unresolvedMaterializations;
 
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
   op->erase();
 }
 
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+    const TypeConverter *converter, MaterializationKind kind)
+    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+      converterAndKind(converter, kind) {
+  rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
 void UnresolvedMaterializationRewrite::rollback() {
   if (getMaterializationKind() == MaterializationKind::Target) {
     for (Value input : op->getOperands())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.insert(convertOp);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
-  for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    if (rewriterImpl.eraseRewriter.wasErased(mat))
+  const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      &materializations = rewriterImpl.unresolvedMaterializations;
+  for (auto it : materializations) {
+    if (rewriterImpl.eraseRewriter.wasErased(it.first))
       continue;
-    allCastOps.push_back(mat->getOperation());
-    rewriteMap[mat->getOperation()] = mat;
+    allCastOps.push_back(cast<UnrealizedConversionCastOp>(it.first));
   }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (config.buildMaterializations) {
     IRRewriter rewriter(rewriterImpl.context, config.listener);
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
-      auto it = rewriteMap.find(castOp.getOperation());
-      assert(it != rewriteMap.end() && "inconsistent state");
+      auto it = materializations.find(castOp.getOperation());
+      assert(it != materializations.end() && "inconsistent state");
       if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
         return failure();
     }



More information about the llvm-branch-commits mailing list