[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn op creation into `IRRewrite` (PR #81759)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 22 02:12:03 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81759

>From d15c439f166426a613ac0021a9101b774d544357 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 16 Feb 2024 15:21:48 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn op creation into `IRRewrite`

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an `IRRewrite` that can be committed and rolled back just like any other rewrite. This commit simplifies the internal state of the dialect conversion.
---
 .../Transforms/Utils/DialectConversion.cpp    | 101 +++++++++++-------
 1 file changed, 63 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index dec68048dc1d30..ae81dfa4d8303b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,17 +152,12 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
-                unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased)
-      : numCreatedOps(numCreatedOps),
-        numUnresolvedMaterializations(numUnresolvedMaterializations),
+  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
+                unsigned numIgnoredOperations, unsigned numErased)
+      : numUnresolvedMaterializations(numUnresolvedMaterializations),
         numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of created operations.
-  unsigned numCreatedOps;
-
   /// The current number of unresolved materializations.
   unsigned numUnresolvedMaterializations;
 
@@ -303,7 +298,8 @@ class IRRewrite {
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
-    ReplaceOperation
+    ReplaceOperation,
+    CreateOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -376,7 +372,10 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
-    eraseBlock(block);
+    if (block->getParent())
+      eraseBlock(block);
+    else
+      delete block;
   }
 };
 
@@ -606,7 +605,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::ReplaceOperation;
+           rewrite->getKind() <= Kind::CreateOperation;
   }
 
 protected:
@@ -740,6 +739,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
   /// A boolean flag that indicates whether result types have changed or not.
   bool changedResults;
 };
+
+/// Rewrite recording that `op` was newly created during the conversion.
+class CreateOperationRewrite : public OperationRewrite {
+public:
+  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                         Operation *op)
+      : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::CreateOperation;
+  }
+  /// Undo the creation: unlink nested blocks, drop uses and erase `op`.
+  void rollback() override;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -957,9 +969,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all of the newly created operations during conversion.
-  SmallVector<Operation *> createdOps;
-
   /// Ordered vector of all unresolved type conversion materializations during
   /// conversion.
   SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1144,6 +1153,15 @@ void ReplaceOperationRewrite::rollback() {
 
 void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
 
+void CreateOperationRewrite::rollback() {
+  for (Region &region : op->getRegions()) { // Detach all nested blocks.
+    while (!region.getBlocks().empty())
+      region.getBlocks().remove(region.getBlocks().begin()); // Unlink only.
+  }
+  op->dropAllUses(); // Uses must be dropped before erasing.
+  eraseOp(op);
+}
+
 void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
@@ -1161,8 +1179,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
   // Remove any newly created ops.
   for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
     detachNestedAndErase(materialization.getOp());
-  for (auto *op : llvm::reverse(createdOps))
-    detachNestedAndErase(op);
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1182,9 +1198,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
-                       rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size());
+  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
+                       ignoredOps.size(), eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1205,12 +1220,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     detachNestedAndErase(op);
   }
 
-  // Pop all of the newly created operations.
-  while (createdOps.size() != state.numCreatedOps) {
-    detachNestedAndErase(createdOps.back());
-    createdOps.pop_back();
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1478,7 +1487,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
   });
   if (!previous.isSet()) {
     // This is a newly created op.
-    createdOps.push_back(op);
+    appendRewrite<CreateOperationRewrite>(op);
     return;
   }
   Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1979,13 +1988,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
-  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
        i != e; ++i) {
-    Operation *cstOp = rewriterImpl.createdOps[i];
-    if (failed(legalize(cstOp, rewriter))) {
+    auto *createOp =
+        dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    if (failed(legalize(createOp->getOperation(), rewriter))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
-                            cstOp->getName()));
+                            createOp->getOperation()->getName()));
       rewriterImpl.resetState(curState);
       return failure();
     }
@@ -2132,9 +2144,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // blocks in regions created by this pattern will already be legalized later
     // on. If we haven't built the set yet, build it now.
     if (operationsToIgnore.empty()) {
-      auto createdOps = ArrayRef<Operation *>(impl.createdOps)
-                            .drop_front(state.numCreatedOps);
-      operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+      for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
+           ++i) {
+        auto *createOp =
+            dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+        if (!createOp)
+          continue;
+        operationsToIgnore.insert(createOp->getOperation());
+      }
     }
 
     // If this operation should be considered for re-legalization, try it.
@@ -2152,8 +2169,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
-    Operation *op = impl.createdOps[i];
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    Operation *op = createOp->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(impl.logger,
                             "failed to legalize generated operation '{0}'({1})",
@@ -2583,10 +2603,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
     });
     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
   };
-  for (auto &r : rewriterImpl.rewrites)
-    if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
-      if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+  // Note: `rewrites` may be reallocated as the loop is running.
+  for (int64_t i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+    auto &rewrite = rewriterImpl.rewrites[i];
+    if (auto *blockTypeConversionRewrite =
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+              findLiveUser)))
         return failure();
+  }
   return success();
 }
 



More information about the llvm-branch-commits mailing list